{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4e0b23d8",
   "metadata": {},
   "source": [
    "# Reproducibility Package for Optimizing Economic Complexity\n",
    "\n",
    "\n",
    "This notebook can be used to reproduce the results of the manuscript \"Optimizing Economic Complexity\" by Viktor Stojkoski and César A. Hidalgo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78990528",
   "metadata": {},
   "source": [
    "## Upload required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af94e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.linalg import eig\n",
    "from scipy.stats import zscore\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.io import loadmat\n",
    "from scipy.optimize import linprog\n",
    "from scipy.optimize import minimize\n",
    "from scipy.stats import norm\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "from numpy.linalg import inv\n",
    "from itertools import product\n",
    "import ecioptimization as eciopt\n",
    "\n",
    "from statsmodels.iolib.summary2 import summary_col\n",
    "from sklearn.metrics import mean_absolute_error, f1_score\n",
    "from statsmodels.formula.api import ols\n",
    "import seaborn as sns\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22e7782a",
   "metadata": {},
   "source": [
    "## Calculate complexity of Cars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30d8cd46",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_year = 2002\n",
    "end_year = 2022\n",
    "hs_code = 8703\n",
    "\n",
    "\n",
    "filename = 'trade_data/bilateral_' + str(start_year) +'.csv'\n",
    "data_start = pd.read_csv(filename)\n",
    "\n",
    "X_start = data_start.pivot_table(values='value', index='exporter_id', columns='hs_code', \n",
    "                             aggfunc=np.sum, fill_value=0)  # Replace 'TradeValues' with the exact column name in your CSV file\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_start = X_start[X_start.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_start = X_start.loc[:, X_start.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_start = X_start.index.tolist()\n",
    "categories_start = X_start.columns.tolist()\n",
    "\n",
    "\n",
    "filename = 'trade_data/bilateral_' + str(end_year) +'.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "X_end = data_end.pivot_table(values='value', index='exporter_id', columns='hs_code', \n",
    "                             aggfunc=np.sum, fill_value=0)  \n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_end = X_end[X_end.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_end = X_end.loc[:, X_end.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_end = X_end.index.tolist()\n",
    "categories_end = X_end.columns.tolist()\n",
    "\n",
    "# Find the intersection of countries and categories\n",
    "countries_probit = np.intersect1d(countries_start, countries_end)\n",
    "categories_probit = np.intersect1d(categories_start, categories_end)\n",
    "\n",
    "# Function to find indices\n",
    "def find_indices(original, to_find):\n",
    "    return [original.index(item) for item in to_find if item in original]\n",
    "\n",
    "\n",
    "# Finding indices\n",
    "z_countries_start = find_indices(countries_start, countries_probit)\n",
    "z_countries_end = find_indices(countries_end, countries_probit)\n",
    "\n",
    "z_categories_start = find_indices(categories_start, categories_probit)\n",
    "z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "# Creating matrices for start, mid, and end years using .iloc\n",
    "X_mat_start = X_start.iloc[z_countries_start, z_categories_start]\n",
    "X_mat_end = X_end.iloc[z_countries_end, z_categories_end]\n",
    "\n",
    "\n",
    "# Assuming rca and cplex_rank functions are already defined in Python\n",
    "RCA_start = eciopt.rca(X_mat_start)\n",
    "RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "M_start = (RCA_start > 1).astype(float)\n",
    "M_end = (RCA_end > 1).astype(float)\n",
    "\n",
    "# Run cplex_rank function for start and end matrices\n",
    "_, ProductRankings_start, _, _ = eciopt.cplex_rank(M_start, countries_probit, categories_probit)\n",
    "\n",
    "# Run cplex_rank function for start and end matrices\n",
    "_, ProductRankings_end, _, _ = eciopt.cplex_rank(M_end, countries_probit, categories_probit)\n",
    "\n",
    "# Sort ProductRankings_start by 'PCI' in descending order and reset the index\n",
    "ProductRankings_start_sorted = ProductRankings_start.sort_values(by='PCI', ascending=False).reset_index(drop=True)\n",
    "total_activities_start = len(ProductRankings_start_sorted)\n",
    "\n",
    "# Find the PCI value and rank for the specified hs_code in the start year\n",
    "pci_value_start = ProductRankings_start_sorted.loc[ProductRankings_start_sorted['Product'] == hs_code, 'PCI'].values[0]\n",
    "rank_start = ProductRankings_start_sorted[ProductRankings_start_sorted['Product'] == hs_code].index[0] + 1  # Rank is index + 1\n",
    "\n",
    "# Print the PCI, rank, and total number of activities for the start year\n",
    "print(f\"PCI of Cars (HS Code: {hs_code}) in {start_year} is {pci_value_start}, ranked {rank_start} out of {total_activities_start} activities.\")\n",
    "\n",
    "# Sort ProductRankings_end by 'PCI' in descending order and reset the index\n",
    "ProductRankings_end_sorted = ProductRankings_end.sort_values(by='PCI', ascending=False).reset_index(drop=True)\n",
    "total_activities_end = len(ProductRankings_end_sorted)\n",
    "\n",
    "# Find the PCI value and rank for the specified hs_code in the end year\n",
    "pci_value_end = ProductRankings_end_sorted.loc[ProductRankings_end_sorted['Product'] == hs_code, 'PCI'].values[0]\n",
    "rank_end = ProductRankings_end_sorted[ProductRankings_end_sorted['Product'] == hs_code].index[0] + 1  # Rank is index + 1\n",
    "\n",
    "# Print the PCI, rank, and total number of activities for the end year\n",
    "print(f\"PCI of Cars (HS Code: {hs_code}) in {end_year} is {pci_value_end}, ranked {rank_end} out of {total_activities_end} activities.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa834c83",
   "metadata": {},
   "source": [
    "## Estimate National level models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42da80a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize lists to store the rows\n",
    "beta_params_entry_list = []\n",
    "beta_std_entry_list = []\n",
    "beta_pvalues_entry_list = []\n",
    "beta_rsquared_entry_list = []\n",
    "\n",
    "beta_params_exit_list = []\n",
    "beta_std_exit_list = []\n",
    "beta_pvalues_exit_list = []\n",
    "beta_rsquared_exit_list = []\n",
    "\n",
    "# Define the range of delta_t and tau\n",
    "min_delta_t = 8\n",
    "max_delta_t = 15  # Adjust as needed\n",
    "min_tau = 3\n",
    "\n",
    "# Loop over possible delta_t values\n",
    "for delta_t in range(min_delta_t, max_delta_t + 1):\n",
    "    max_tau = delta_t - 3  # tau must be less than delta_t\n",
    "    # Loop over possible tau values\n",
    "    for tau in range(min_tau, max_tau + 1):\n",
    "        # Calculate the maximum start_year so that end_year does not exceed 2022\n",
    "        max_start_year = 2022 - delta_t\n",
    "        # Loop over possible start_year values\n",
    "        for start_year in range(1999, max_start_year + 1):\n",
    "            mid_year = start_year + tau\n",
    "            end_year = start_year + delta_t\n",
    "\n",
    "            # Read data for start_year\n",
    "            filename = 'trade_data/bilateral_' + str(start_year) + '.csv'\n",
    "            try:\n",
    "                data_start = pd.read_csv(filename)\n",
    "            except FileNotFoundError:\n",
    "                print(f\"File not found: {filename}\")\n",
    "                continue\n",
    "\n",
    "            X_start = data_start.pivot_table(values='value', index='exporter_id', columns='hs_code', \n",
    "                                             aggfunc=np.sum, fill_value=0)\n",
    "\n",
    "            # Remove rows with sum less than 10^9\n",
    "            X_start = X_start[X_start.sum(axis=1) >= 10**9]\n",
    "\n",
    "            # Remove columns with sum less than 500,000\n",
    "            X_start = X_start.loc[:, X_start.sum(axis=0) >= 500000]\n",
    "\n",
    "            countries_start = X_start.index.tolist()\n",
    "            categories_start = X_start.columns.tolist()\n",
    "\n",
    "            # Read data for mid_year\n",
    "            filename = 'trade_data/bilateral_' + str(mid_year) + '.csv'\n",
    "            try:\n",
    "                data_mid = pd.read_csv(filename)\n",
    "            except FileNotFoundError:\n",
    "                print(f\"File not found: {filename}\")\n",
    "                continue\n",
    "\n",
    "            X_mid = data_mid.pivot_table(values='value', index='exporter_id', columns='hs_code', \n",
    "                                         aggfunc=np.sum, fill_value=0)\n",
    "\n",
    "            # Remove rows with sum less than 10^9\n",
    "            X_mid = X_mid[X_mid.sum(axis=1) >= 10**9]\n",
    "\n",
    "            # Remove columns with sum less than 500,000\n",
    "            X_mid = X_mid.loc[:, X_mid.sum(axis=0) >= 500000]\n",
    "\n",
    "            countries_mid = X_mid.index.tolist()\n",
    "            categories_mid = X_mid.columns.tolist()\n",
    "\n",
    "            # Read data for end_year\n",
    "            filename = 'trade_data/bilateral_' + str(end_year) + '.csv'\n",
    "            try:\n",
    "                data_end = pd.read_csv(filename)\n",
    "            except FileNotFoundError:\n",
    "                print(f\"File not found: {filename}\")\n",
    "                continue\n",
    "\n",
    "            X_end = data_end.pivot_table(values='value', index='exporter_id', columns='hs_code', \n",
    "                                         aggfunc=np.sum, fill_value=0)\n",
    "\n",
    "            # Remove rows with sum less than 10^9\n",
    "            X_end = X_end[X_end.sum(axis=1) >= 10**9]\n",
    "\n",
    "            # Remove columns with sum less than 500,000\n",
    "            X_end = X_end.loc[:, X_end.sum(axis=0) >= 500000]\n",
    "\n",
    "            countries_end = X_end.index.tolist()\n",
    "            categories_end = X_end.columns.tolist()\n",
    "\n",
    "            # Find the intersection of countries and categories\n",
    "            countries_probit = np.intersect1d(countries_start, countries_mid)\n",
    "            countries_probit = np.intersect1d(countries_probit, countries_end)\n",
    "            categories_probit = np.intersect1d(categories_start, categories_mid)\n",
    "            categories_probit = np.intersect1d(categories_probit, categories_end)\n",
    "\n",
    "            # Skip iteration if no common countries or categories\n",
    "            if len(countries_probit) == 0 or len(categories_probit) == 0:\n",
    "                continue\n",
    "\n",
    "            # Function to find indices\n",
    "            def find_indices(original, to_find):\n",
    "                return [original.index(item) for item in to_find if item in original]\n",
    "\n",
    "            # Finding indices\n",
    "            z_countries_start = find_indices(countries_start, countries_probit)\n",
    "            z_countries_mid = find_indices(countries_mid, countries_probit)\n",
    "            z_countries_end = find_indices(countries_end, countries_probit)\n",
    "\n",
    "            z_categories_start = find_indices(categories_start, categories_probit)\n",
    "            z_categories_mid = find_indices(categories_mid, categories_probit)\n",
    "            z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "            # Creating matrices for start, mid, and end years using .iloc\n",
    "            X_mat_start = X_start.iloc[z_countries_start, z_categories_start]\n",
    "            X_mat_mid = X_mid.iloc[z_countries_mid, z_categories_mid]\n",
    "            X_mat_end = X_end.iloc[z_countries_end, z_categories_end]\n",
    "\n",
    "            # Assuming rca and cplex_rank functions are already defined\n",
    "            RCA_start = eciopt.rca(X_mat_start)\n",
    "            RCA_mid = eciopt.rca(X_mat_mid)\n",
    "            RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "            M_start = (RCA_start > 1).astype(float)\n",
    "       \n",
    "            # Run cplex_rank function for start matrices\n",
    "            try:\n",
    "                _, _, Relatedness_start, _ = eciopt.cplex_rank(M_start, countries_probit, categories_probit)\n",
    "            except Exception as e:\n",
    "                print(f\"Error in cplex_rank for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Flatten the Relatedness_start matrix\n",
    "            Relatedness_start = Relatedness_start.flatten()\n",
    "\n",
    "            # Create repeated arrays for countries and products\n",
    "            countries_all = np.repeat(countries_probit, len(categories_probit))\n",
    "            products_all = np.tile(categories_probit, len(countries_probit))\n",
    "\n",
    "            RR = RCA_start.flatten()\n",
    "            RR_mid = RCA_mid.flatten()\n",
    "            RR_end = RCA_end.flatten()\n",
    "\n",
    "            # Create a DataFrame for analysis\n",
    "            probit_data = pd.DataFrame({\n",
    "                'countries_all': countries_all,\n",
    "                'products_all': products_all,\n",
    "                'Relatedness_start': Relatedness_start,\n",
    "                'RCA_start': np.log(1 + RR),\n",
    "                'RCA_mid': np.log(1 + RR_mid),\n",
    "                'RCA_end': np.log(1 + RR_end)\n",
    "            })\n",
    "\n",
    "            # Split the data into 'entry' and 'exit' subsets\n",
    "            entry_data = probit_data[probit_data['RCA_start'] < np.log(2)].copy()\n",
    "            exit_data = probit_data[probit_data['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "            # Skip if there are not enough data points\n",
    "            if entry_data.empty or exit_data.empty:\n",
    "                continue\n",
    "\n",
    "            # Calculate the z-score of Relatedness_start separately for each subset\n",
    "            entry_data['Relative_relatedness_start'] = entry_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "            exit_data['Relative_relatedness_start'] = exit_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "            # Define the model formula\n",
    "            formula = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "            # Fit the model for the 'entry' subset\n",
    "            try:\n",
    "                model_entry = ols(formula=formula, data=entry_data)\n",
    "                result_entry = model_entry.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting entry model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Fit the model for the 'exit' subset\n",
    "            try:\n",
    "                model_exit = ols(formula=formula, data=exit_data)\n",
    "                result_exit = model_exit.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting exit model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Extract parameters, standard errors, and p-values\n",
    "            beta_params_entry = result_entry.params\n",
    "            beta_std_entry = result_entry.bse\n",
    "            beta_pvalues_entry = result_entry.pvalues\n",
    "            beta_rsquared_entry = result_entry.rsquared\n",
    "\n",
    "            beta_params_exit = result_exit.params\n",
    "            beta_std_exit = result_exit.bse\n",
    "            beta_pvalues_exit = result_exit.pvalues\n",
    "            beta_rsquared_exit = result_exit.rsquared            \n",
    "\n",
    "            # Prepare rows to append to the lists\n",
    "            row_params_entry = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_std_entry = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_pvalues_entry = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_rsquared_entry = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau, 'R2': beta_rsquared_entry}\n",
    "            \n",
    "            row_params_exit = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_std_exit = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_pvalues_exit = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "            row_rsquared_exit = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau, 'R2': beta_rsquared_exit}\n",
    "            \n",
    "            # Add parameters to the rows\n",
    "            for param_name in beta_params_entry.index:\n",
    "                row_params_entry['beta_' + param_name] = beta_params_entry[param_name]\n",
    "                row_std_entry['beta_' + param_name] = beta_std_entry[param_name]\n",
    "                row_pvalues_entry['beta_' + param_name] = beta_pvalues_entry[param_name]\n",
    "\n",
    "            for param_name in beta_params_exit.index:\n",
    "                row_params_exit['beta_' + param_name] = beta_params_exit[param_name]\n",
    "                row_std_exit['beta_' + param_name] = beta_std_exit[param_name]\n",
    "                row_pvalues_exit['beta_' + param_name] = beta_pvalues_exit[param_name]\n",
    "\n",
    "            # Append the rows to the lists\n",
    "            beta_params_entry_list.append(row_params_entry)\n",
    "            beta_std_entry_list.append(row_std_entry)\n",
    "            beta_pvalues_entry_list.append(row_pvalues_entry)\n",
    "            beta_rsquared_entry_list.append(row_rsquared_entry)            \n",
    "\n",
    "            beta_params_exit_list.append(row_params_exit)\n",
    "            beta_std_exit_list.append(row_std_exit)\n",
    "            beta_pvalues_exit_list.append(row_pvalues_exit)\n",
    "            beta_rsquared_exit_list.append(row_rsquared_exit)            \n",
    "            \n",
    "            display(\"Start year:\", start_year)\n",
    "            display(\"Mid year:\", mid_year)\n",
    "            display(\"End year:\", end_year)\n",
    "\n",
    "# After the loops, create DataFrames from the lists\n",
    "BETA_params_country_entry = pd.DataFrame(beta_params_entry_list)\n",
    "BETA_std_country_entry = pd.DataFrame(beta_std_entry_list)\n",
    "BETA_pvalue_country_entry = pd.DataFrame(beta_pvalues_entry_list)\n",
    "BETA_rsquared_country_entry = pd.DataFrame(beta_rsquared_entry_list)\n",
    "\n",
    "\n",
    "BETA_params_country_exit = pd.DataFrame(beta_params_exit_list)\n",
    "BETA_std_country_exit = pd.DataFrame(beta_std_exit_list)\n",
    "BETA_pvalue_country_exit = pd.DataFrame(beta_pvalues_exit_list)\n",
    "BETA_rsquared_country_exit = pd.DataFrame(beta_rsquared_exit_list)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e162be50",
   "metadata": {},
   "source": [
    "### Figure S1 NATIONAL LEVEL: Heatmaps for the dependence of the ENTRY model coefficients on \\Delta t and \\tau\n",
    "\n",
    "### Figure S2. NATIONAL LEVEL: Heatmaps showing the coefficient of determination as a function of \\Delta t and \\tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba2ae882",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure delta_t and tau are in all DataFrames\n",
    "for df in [BETA_params_country_entry, BETA_std_country_entry, BETA_pvalue_country_entry]:\n",
    "    df['delta_t'] = df['end_year'] - df['start_year']\n",
    "    df['tau'] = df['mid_year'] - df['start_year']\n",
    "    df['delta_t'] = df['delta_t'].astype(int)\n",
    "    df['tau'] = df['tau'].astype(int)\n",
    "\n",
    "# Exclude 'beta_Intercept' from beta columns\n",
    "beta_columns = [col for col in BETA_params_country_entry.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# Mapping from beta column names to LaTeX labels\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# Update delta_t and tau labels for axes using LaTeX\n",
    "delta_t_label = '$\\\\Delta t$'\n",
    "tau_label = '$\\\\tau$'\n",
    "\n",
    "# Group by delta_t and tau and calculate the mean of beta parameters\n",
    "grouped_params = BETA_params_country_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_country_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_country_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "\n",
    "# Number of beta parameters to plot (excluding intercept)\n",
    "num_beta_params = len(beta_columns)\n",
    "\n",
    "# Create a figure with subplots\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params))\n",
    "\n",
    "# If only one beta parameter, ensure axes is a 2D array\n",
    "if num_beta_params == 1:\n",
    "    axes = np.array([axes])\n",
    "\n",
    "# For each beta parameter, create pivot tables and plot heatmaps\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # Pivot tables for beta parameter, standard error, and p-value\n",
    "    pivot_params = grouped_params.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    pivot_std = grouped_std.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    pivot_pvalues = grouped_pvalues.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    \n",
    "    # Get the LaTeX label for the beta parameter\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "    \n",
    "    # Plot beta parameter heatmap\n",
    "    sns.heatmap(pivot_params, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': beta_label}, ax=axes[i, 0])\n",
    "    axes[i, 0].set_title(beta_label)\n",
    "    axes[i, 0].set_ylabel(tau_label)\n",
    "    axes[i, 0].set_xlabel(delta_t_label)\n",
    "    \n",
    "    # Plot standard error heatmap\n",
    "    sns.heatmap(pivot_std, annot=True, fmt=\".2f\", cmap='viridis', cbar_kws={'label': 'Std Error'}, ax=axes[i, 1])\n",
    "    axes[i, 1].set_title(f'Std Error of {beta_label}')\n",
    "    axes[i, 1].set_ylabel(tau_label)\n",
    "    axes[i, 1].set_xlabel(delta_t_label)\n",
    "    \n",
    "    # Plot p-value heatmap\n",
    "    sns.heatmap(pivot_pvalues, annot=True, fmt=\".2f\", cmap='magma_r', cbar_kws={'label': 'p-value'}, ax=axes[i, 2])\n",
    "    axes[i, 2].set_title(f'p-value of {beta_label}')\n",
    "    axes[i, 2].set_ylabel(tau_label)\n",
    "    axes[i, 2].set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('figure_s1.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# Assuming R² values have been computed and added to the DataFrame\n",
    "# Add a column 'R2' to each DataFrame, if not already calculated\n",
    "BETA_rsquared_country_entry['delta_t'] = BETA_params_country_entry['delta_t']\n",
    "BETA_rsquared_country_entry['tau'] = BETA_params_country_entry['tau']\n",
    "\n",
    "# Group by delta_t and tau to get the mean R² value\n",
    "grouped_R2 = BETA_rsquared_country_entry.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "\n",
    "# Pivot table for R² values\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# Create the heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'})\n",
    "plt.title('Average $R^2$ of entry models')\n",
    "plt.ylabel(tau_label)\n",
    "plt.xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s2.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a064cea6",
   "metadata": {},
   "source": [
    "### Figure S3 NATIONAL LEVEL:  Heatmaps for the dependence of the EXIT model coefficients on \\Delta t and \\tau\n",
    "\n",
    "### Figure S4 NATIONAL LEVELS. Heatmaps showing the coefficient of determination as a function of \\Delta t and \\tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "557a80e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure delta_t and tau are in all DataFrames\n",
    "for df in [BETA_params_country_exit, BETA_std_country_exit, BETA_pvalue_country_exit]:\n",
    "    df['delta_t'] = df['end_year'] - df['start_year']\n",
    "    df['tau'] = df['mid_year'] - df['start_year']\n",
    "    df['delta_t'] = df['delta_t'].astype(int)\n",
    "    df['tau'] = df['tau'].astype(int)\n",
    "\n",
    "# Exclude 'beta_Intercept' from beta columns\n",
    "beta_columns = [col for col in BETA_params_country_exit.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# Mapping from beta column names to LaTeX labels\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# Update delta_t and tau labels for axes using LaTeX\n",
    "delta_t_label = '$\\\\Delta t$'\n",
    "tau_label = '$\\\\tau$'\n",
    "\n",
    "# Group by delta_t and tau and calculate the mean of beta parameters\n",
    "grouped_params = BETA_params_country_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_country_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_country_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "\n",
    "# Number of beta parameters to plot (excluding intercept)\n",
    "num_beta_params = len(beta_columns)\n",
    "\n",
    "# Create a figure with subplots\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params))\n",
    "\n",
    "# If only one beta parameter, ensure axes is a 2D array\n",
    "if num_beta_params == 1:\n",
    "    axes = np.array([axes])\n",
    "\n",
    "# For each beta parameter, create pivot tables and plot heatmaps\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # Pivot tables for beta parameter, standard error, and p-value\n",
    "    pivot_params = grouped_params.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    pivot_std = grouped_std.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    pivot_pvalues = grouped_pvalues.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "    \n",
    "    # Get the LaTeX label for the beta parameter\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "    \n",
    "    # Plot beta parameter heatmap\n",
    "    sns.heatmap(pivot_params, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': beta_label}, ax=axes[i, 0])\n",
    "    axes[i, 0].set_title(beta_label)\n",
    "    axes[i, 0].set_ylabel(tau_label)\n",
    "    axes[i, 0].set_xlabel(delta_t_label)\n",
    "    \n",
    "    # Plot standard error heatmap\n",
    "    sns.heatmap(pivot_std, annot=True, fmt=\".2f\", cmap='viridis', cbar_kws={'label': 'Std Error'}, ax=axes[i, 1])\n",
    "    axes[i, 1].set_title(f'Std Error of {beta_label}')\n",
    "    axes[i, 1].set_ylabel(tau_label)\n",
    "    axes[i, 1].set_xlabel(delta_t_label)\n",
    "    \n",
    "    # Plot p-value heatmap\n",
    "    sns.heatmap(pivot_pvalues, annot=True, fmt=\".2f\", cmap='magma_r', cbar_kws={'label': 'p-value'}, ax=axes[i, 2])\n",
    "    axes[i, 2].set_title(f'p-value of {beta_label}')\n",
    "    axes[i, 2].set_ylabel(tau_label)\n",
    "    axes[i, 2].set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('figure_s3.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# Assuming R² values have been computed and added to the DataFrame\n",
    "# Add a column 'R2' to each DataFrame, if not already calculated\n",
    "BETA_rsquared_country_exit['delta_t'] = BETA_params_country_exit['delta_t']\n",
    "BETA_rsquared_country_exit['tau'] = BETA_params_country_exit['tau']\n",
    "\n",
    "# Group by delta_t and tau to get the mean R² value\n",
    "grouped_R2 = BETA_rsquared_country_exit.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "\n",
    "# Pivot table for R² values\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# Create the heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'})\n",
    "plt.title('Average $R^2$ of exit models')\n",
    "plt.ylabel(tau_label)\n",
    "plt.xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s4.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0c6a8b0",
   "metadata": {},
   "source": [
    "### Figure S5 NATIONAL LEVEL:  Histograms for the model coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e52d6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficient histograms are drawn for one (delta_t, tau) specification\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the models estimated with that specification\n",
    "filtered_params_entry = BETA_params_country_entry[\n",
    "    (BETA_params_country_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_country_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_country_exit[\n",
    "    (BETA_params_country_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_country_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# Exclude 'beta_Intercept' from beta columns\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# Coefficients present in both models, kept in entry-model column order.\n",
    "# A set intersection has no defined order, so the subplot rows of the saved\n",
    "# figure could differ from one kernel session to the next.\n",
    "beta_columns = [col for col in beta_columns_entry if col in beta_columns_exit]\n",
    "\n",
    "# If there are no common beta columns, print a message and skip the figure\n",
    "if not beta_columns:\n",
    "    print(\"No common beta coefficients found between entry and exit models.\")\n",
    "else:\n",
    "    # Mapping from beta column names to LaTeX labels\n",
    "    latex_labels = {\n",
    "        'beta_RCA_mid': '$R_{cp}(t+\\tau)$',\n",
    "        'beta_RCA_start': '$R_{cp}(t)$',\n",
    "        'beta_Relatedness_start': '$\\omega_{cp}(t)$',\n",
    "        'beta_Relative_relatedness_start': '$\\tilde{\\omega}_{cp}(t)$'\n",
    "    }\n",
    "\n",
    "    # Set up the plotting style\n",
    "    sns.set(style='whitegrid')\n",
    "\n",
    "    # One row of subplots per coefficient, one column per model\n",
    "    num_beta_params = len(beta_columns)\n",
    "\n",
    "    # squeeze=False always returns a 2D array of axes, even when there is a single row\n",
    "    fig, axes = plt.subplots(nrows=num_beta_params, ncols=2, figsize=(14, 4 * num_beta_params), squeeze=False)\n",
    "\n",
    "    # (model name used in the title, filtered coefficients, histogram color) per subplot column\n",
    "    model_panels = [\n",
    "        ('Entry', filtered_params_entry, 'skyblue'),\n",
    "        ('Exit', filtered_params_exit, 'salmon'),\n",
    "    ]\n",
    "\n",
    "    # Plot histograms for each beta parameter\n",
    "    for i, beta_col in enumerate(beta_columns):\n",
    "        beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "        for j, (model_name, model_params, color) in enumerate(model_panels):\n",
    "            ax = axes[i, j]\n",
    "            if not model_params.empty:\n",
    "                sns.histplot(model_params[beta_col], kde=True, bins=10, color=color, ax=ax)\n",
    "            else:\n",
    "                # Nothing was estimated for this specification: show a placeholder instead\n",
    "                ax.text(0.5, 0.5, 'No data available', ha='center', va='center')\n",
    "            ax.set_title(f'{model_name} Model: Coefficient for {beta_label}')\n",
    "            ax.set_xlabel('Coefficient Value')\n",
    "            ax.set_ylabel('Frequency')\n",
    "\n",
    "    # Adjust layout\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Save the figure as a high-quality PNG\n",
    "    plt.savefig('figure_s5.png', dpi=300, bbox_inches='tight')\n",
    "\n",
    "    # Display the figure\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a17f9eb2",
   "metadata": {},
   "source": [
    "### NATIONAL LEVEL: Calculate coefficients used in the optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd45f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specification (delta_t, tau) whose averaged coefficients are computed below\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the models estimated with that specification\n",
    "filtered_params_entry = BETA_params_country_entry[\n",
    "    (BETA_params_country_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_country_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_country_exit[\n",
    "    (BETA_params_country_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_country_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# All 'beta_*' columns; unlike the histogram cells, 'beta_Intercept' is kept here\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_')]\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_')]\n",
    "\n",
    "# Average each coefficient over all models fitted with this specification.\n",
    "# NOTE: when no data is available the corresponding beta_country_* variable is not assigned.\n",
    "if not filtered_params_entry.empty and beta_columns_entry:\n",
    "    # Compute the mean of the beta coefficients for the entry model\n",
    "    beta_country_entry = filtered_params_entry[beta_columns_entry].mean()\n",
    "    print(f\"Average beta coefficients for entry model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_country_entry)\n",
    "else:\n",
    "    print(f\"No data available for entry model with delta_t = {delta_t_value} and tau = {tau_value}.\")\n",
    "\n",
    "if not filtered_params_exit.empty and beta_columns_exit:\n",
    "    # Compute the mean of the beta coefficients for the exit model\n",
    "    beta_country_exit = filtered_params_exit[beta_columns_exit].mean()\n",
    "    print(f\"\\nAverage beta coefficients for exit model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_country_exit)\n",
    "else:\n",
    "    print(f\"No data available for exit model with delta_t = {delta_t_value} and tau = {tau_value}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09b782c5",
   "metadata": {},
   "source": [
    "## Estimate MSA level models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ee7b71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row accumulators: one dict per fitted model, converted to DataFrames after the loops\n",
    "beta_params_entry_list = []\n",
    "beta_std_entry_list = []\n",
    "beta_pvalues_entry_list = []\n",
    "beta_rsquared_entry_list = []\n",
    "\n",
    "beta_params_exit_list = []\n",
    "beta_std_exit_list = []\n",
    "beta_pvalues_exit_list = []\n",
    "beta_rsquared_exit_list = []\n",
    "\n",
    "# Define the range of delta_t and tau\n",
    "min_delta_t = 8\n",
    "max_delta_t = 15  # Adjust as needed\n",
    "min_tau = 3\n",
    "\n",
    "# Minimum yearly 'ap_avg' total needed to keep an MSA (row) or a NAICS category (column).\n",
    "# The same thresholds are applied to the start, mid and end years.\n",
    "min_msa_total = 10**5\n",
    "min_category_total = 10**5.5\n",
    "\n",
    "# Filtered matrices keyed by year, so each yearly CSV is read and pivoted only once\n",
    "msa_matrix_cache = {}\n",
    "\n",
    "\n",
    "def load_msa_matrix(year):\n",
    "    \"\"\"Return the filtered MSA x NAICS matrix of summed 'ap_avg' values for `year`.\n",
    "\n",
    "    Reads 'msa_data/msa_data_<year>.csv', pivots it to one row per 'msa' and one\n",
    "    column per 'naics' (missing combinations filled with 0), keeps the rows whose\n",
    "    total exceeds `min_msa_total` and then, among the remaining rows, the columns\n",
    "    whose total exceeds `min_category_total`. Results are memoized in\n",
    "    `msa_matrix_cache`.\n",
    "    \"\"\"\n",
    "    if year not in msa_matrix_cache:\n",
    "        data = pd.read_csv('msa_data/msa_data_' + str(year) + '.csv')\n",
    "        X = data.pivot_table(values='ap_avg', index='msa', columns='naics',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "        # Drop MSAs whose total is too small\n",
    "        X = X[X.sum(axis=1) > min_msa_total]\n",
    "        # Drop categories whose total (over the remaining MSAs) is too small\n",
    "        X = X.loc[:, X.sum(axis=0) > min_category_total]\n",
    "        msa_matrix_cache[year] = X\n",
    "    return msa_matrix_cache[year]\n",
    "\n",
    "\n",
    "def find_indices(original, to_find):\n",
    "    \"\"\"Return the positions in the list `original` of the items of `to_find` found in it.\"\"\"\n",
    "    return [original.index(item) for item in to_find if item in original]\n",
    "\n",
    "\n",
    "# Loop over possible delta_t values\n",
    "for delta_t in range(min_delta_t, max_delta_t + 1):\n",
    "    max_tau = delta_t - 3  # tau must be less than delta_t\n",
    "    # Loop over possible tau values\n",
    "    for tau in range(min_tau, max_tau + 1):\n",
    "        # Calculate the maximum start_year so that end_year does not exceed 2022\n",
    "        max_start_year = 2022 - delta_t\n",
    "        # Loop over possible start_year values\n",
    "        for start_year in range(2003, max_start_year + 1):\n",
    "            mid_year = start_year + tau\n",
    "            end_year = start_year + delta_t\n",
    "\n",
    "            # Filtered MSA x NAICS matrices for the three years\n",
    "            X_start = load_msa_matrix(start_year)\n",
    "            X_mid = load_msa_matrix(mid_year)\n",
    "            X_end = load_msa_matrix(end_year)\n",
    "\n",
    "            msa_start = X_start.index.tolist()\n",
    "            categories_start = X_start.columns.tolist()\n",
    "\n",
    "            msa_mid = X_mid.index.tolist()\n",
    "            categories_mid = X_mid.columns.tolist()\n",
    "\n",
    "            msa_end = X_end.index.tolist()\n",
    "            categories_end = X_end.columns.tolist()\n",
    "\n",
    "            # Find the intersection of msa and categories\n",
    "            msa_probit = np.intersect1d(msa_start, msa_mid)\n",
    "            msa_probit = np.intersect1d(msa_probit, msa_end)\n",
    "            categories_probit = np.intersect1d(categories_start, categories_mid)\n",
    "            categories_probit = np.intersect1d(categories_probit, categories_end)\n",
    "\n",
    "            # Skip iteration if no common msa or categories\n",
    "            if len(msa_probit) == 0 or len(categories_probit) == 0:\n",
    "                continue\n",
    "\n",
    "            # Positions of the common MSAs and categories in each year's matrix\n",
    "            z_msa_start = find_indices(msa_start, msa_probit)\n",
    "            z_msa_mid = find_indices(msa_mid, msa_probit)\n",
    "            z_msa_end = find_indices(msa_end, msa_probit)\n",
    "\n",
    "            z_categories_start = find_indices(categories_start, categories_probit)\n",
    "            z_categories_mid = find_indices(categories_mid, categories_probit)\n",
    "            z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "            # Creating matrices for start, mid, and end years using .iloc\n",
    "            X_mat_start = X_start.iloc[z_msa_start, z_categories_start]\n",
    "            X_mat_mid = X_mid.iloc[z_msa_mid, z_categories_mid]\n",
    "            X_mat_end = X_end.iloc[z_msa_end, z_categories_end]\n",
    "\n",
    "            # Keep only MSAs and categories with a positive end-year total.\n",
    "            # Both masks are computed once from the end-year matrix and applied to\n",
    "            # all three matrices so that they keep identical shapes.\n",
    "            rows_kept = X_mat_end.sum(axis=1) > 0\n",
    "            cols_kept = X_mat_end.sum(axis=0) > 0\n",
    "            X_mat_start = X_mat_start.loc[rows_kept, cols_kept]\n",
    "            X_mat_mid = X_mat_mid.loc[rows_kept, cols_kept]\n",
    "            X_mat_end = X_mat_end.loc[rows_kept, cols_kept]\n",
    "\n",
    "            msa_probit = X_mat_end.index.tolist()\n",
    "            categories_probit = X_mat_end.columns.tolist()\n",
    "\n",
    "            # RCA matrices from the ecioptimization helpers\n",
    "            RCA_start = eciopt.rca(X_mat_start)\n",
    "            RCA_mid = eciopt.rca(X_mat_mid)\n",
    "            RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "            # Binary specialization matrices: 1.0 where RCA > 1\n",
    "            M_start = (RCA_start > 1).astype(float)\n",
    "            M_mid = (RCA_mid > 1).astype(float)\n",
    "            M_end = (RCA_end > 1).astype(float)\n",
    "\n",
    "            # Run cplex_rank function for start matrices\n",
    "            try:\n",
    "                _, _, Relatedness_start, _ = eciopt.cplex_rank(M_start, msa_probit, categories_probit)\n",
    "            except Exception as e:\n",
    "                print(f\"Error in cplex_rank for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Flatten the Relatedness_start matrix\n",
    "            Relatedness_start = Relatedness_start.flatten()\n",
    "\n",
    "            # Create repeated arrays for msa and products\n",
    "            msa_all = np.repeat(msa_probit, len(categories_probit))\n",
    "            products_all = np.tile(categories_probit, len(msa_probit))\n",
    "\n",
    "            RR = RCA_start.flatten()\n",
    "            RR_mid = RCA_mid.flatten()\n",
    "            RR_end = RCA_end.flatten()\n",
    "\n",
    "            # Create a DataFrame for analysis (RCA values enter as log(1 + RCA))\n",
    "            probit_data = pd.DataFrame({\n",
    "                'msa_all': msa_all,\n",
    "                'products_all': products_all,\n",
    "                'M_end': M_end.flatten(),\n",
    "                'Relatedness_start': Relatedness_start,\n",
    "                'RCA_start': np.log(1 + RR),\n",
    "                'RCA_mid': np.log(1 + RR_mid),\n",
    "                'RCA_end': np.log(1 + RR_end)\n",
    "            })\n",
    "\n",
    "            # Split the data into 'entry' and 'exit' subsets.\n",
    "            # log(1 + RCA) < log(2) is equivalent to RCA < 1.\n",
    "            entry_data = probit_data[probit_data['RCA_start'] < np.log(2)].copy()\n",
    "            exit_data = probit_data[probit_data['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "            # Skip if there are not enough data points\n",
    "            if entry_data.empty or exit_data.empty:\n",
    "                continue\n",
    "\n",
    "            # Calculate the z-score of Relatedness_start within each MSA, separately for each subset\n",
    "            entry_data['Relative_relatedness_start'] = entry_data.groupby('msa_all')['Relatedness_start'].transform(zscore)\n",
    "            exit_data['Relative_relatedness_start'] = exit_data.groupby('msa_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "            # Define the model formula\n",
    "            formula = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "            # Fit the model for the 'entry' subset\n",
    "            try:\n",
    "                model_entry = ols(formula=formula, data=entry_data)\n",
    "                result_entry = model_entry.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting entry model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Fit the model for the 'exit' subset\n",
    "            try:\n",
    "                model_exit = ols(formula=formula, data=exit_data)\n",
    "                result_exit = model_exit.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting exit model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Extract parameters, standard errors, p-values and R-squared\n",
    "            beta_params_entry = result_entry.params\n",
    "            beta_std_entry = result_entry.bse\n",
    "            beta_pvalues_entry = result_entry.pvalues\n",
    "            beta_rsquared_entry = result_entry.rsquared\n",
    "\n",
    "            beta_params_exit = result_exit.params\n",
    "            beta_std_exit = result_exit.bse\n",
    "            beta_pvalues_exit = result_exit.pvalues\n",
    "            beta_rsquared_exit = result_exit.rsquared\n",
    "\n",
    "            # Columns identifying the specification, shared by every output row\n",
    "            row_id = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "\n",
    "            row_params_entry = dict(row_id)\n",
    "            row_std_entry = dict(row_id)\n",
    "            row_pvalues_entry = dict(row_id)\n",
    "            row_rsquared_entry = {**row_id, 'R2': beta_rsquared_entry}\n",
    "\n",
    "            row_params_exit = dict(row_id)\n",
    "            row_std_exit = dict(row_id)\n",
    "            row_pvalues_exit = dict(row_id)\n",
    "            row_rsquared_exit = {**row_id, 'R2': beta_rsquared_exit}\n",
    "\n",
    "            # Add parameters to the rows, one 'beta_<term>' column per model term\n",
    "            for param_name in beta_params_entry.index:\n",
    "                row_params_entry['beta_' + param_name] = beta_params_entry[param_name]\n",
    "                row_std_entry['beta_' + param_name] = beta_std_entry[param_name]\n",
    "                row_pvalues_entry['beta_' + param_name] = beta_pvalues_entry[param_name]\n",
    "\n",
    "            for param_name in beta_params_exit.index:\n",
    "                row_params_exit['beta_' + param_name] = beta_params_exit[param_name]\n",
    "                row_std_exit['beta_' + param_name] = beta_std_exit[param_name]\n",
    "                row_pvalues_exit['beta_' + param_name] = beta_pvalues_exit[param_name]\n",
    "\n",
    "            # Append the rows to the lists\n",
    "            beta_params_entry_list.append(row_params_entry)\n",
    "            beta_std_entry_list.append(row_std_entry)\n",
    "            beta_pvalues_entry_list.append(row_pvalues_entry)\n",
    "            beta_rsquared_entry_list.append(row_rsquared_entry)\n",
    "\n",
    "            beta_params_exit_list.append(row_params_exit)\n",
    "            beta_std_exit_list.append(row_std_exit)\n",
    "            beta_pvalues_exit_list.append(row_pvalues_exit)\n",
    "            beta_rsquared_exit_list.append(row_rsquared_exit)\n",
    "\n",
    "            # Progress output for the specification that was just fitted\n",
    "            display(\"Start year:\", start_year)\n",
    "            display(\"Mid year:\", mid_year)\n",
    "            display(\"End year:\", end_year)\n",
    "\n",
    "# After the loops, create DataFrames from the lists\n",
    "BETA_params_msa_entry = pd.DataFrame(beta_params_entry_list)\n",
    "BETA_std_msa_entry = pd.DataFrame(beta_std_entry_list)\n",
    "BETA_pvalue_msa_entry = pd.DataFrame(beta_pvalues_entry_list)\n",
    "BETA_rsquared_msa_entry = pd.DataFrame(beta_rsquared_entry_list)\n",
    "\n",
    "\n",
    "BETA_params_msa_exit = pd.DataFrame(beta_params_exit_list)\n",
    "BETA_std_msa_exit = pd.DataFrame(beta_std_exit_list)\n",
    "BETA_pvalue_msa_exit = pd.DataFrame(beta_pvalues_exit_list)\n",
    "BETA_rsquared_msa_exit = pd.DataFrame(beta_rsquared_exit_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b6e1fe",
   "metadata": {},
   "source": [
    "### Figure S6 MSA LEVEL: Heatmaps for the dependence of the ENTRY model coefficients on $\Delta t$ and $\tau$\n",
    "\n",
    "### Figure S7 MSA LEVEL: Heatmaps showing the coefficient of determination of the ENTRY models as a function of $\Delta t$ and $\tau$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bb5e977",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute delta_t and tau (as integers) from the year columns of each coefficient table\n",
    "for df in [BETA_params_msa_entry, BETA_std_msa_entry, BETA_pvalue_msa_entry]:\n",
    "    df['delta_t'] = (df['end_year'] - df['start_year']).astype(int)\n",
    "    df['tau'] = (df['mid_year'] - df['start_year']).astype(int)\n",
    "\n",
    "# Slope coefficients only: every 'beta_*' column except the intercept\n",
    "beta_columns = [\n",
    "    col for col in BETA_params_msa_entry.columns\n",
    "    if col.startswith('beta_') and col != 'beta_Intercept'\n",
    "]\n",
    "\n",
    "# LaTeX titles for the coefficients\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\tilde{\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# LaTeX axis labels\n",
    "delta_t_label = '$\\Delta t$'\n",
    "tau_label = '$\\tau$'\n",
    "\n",
    "# Average estimates, standard errors and p-values over all models sharing (delta_t, tau)\n",
    "group_keys = ['delta_t', 'tau']\n",
    "grouped_params = BETA_params_msa_entry.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_msa_entry.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_msa_entry.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "\n",
    "# One row of heatmaps per coefficient: estimate, standard error, p-value\n",
    "num_beta_params = len(beta_columns)\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params), squeeze=False)\n",
    "\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # tau x delta_t grids for this coefficient, with both axes sorted\n",
    "    pivot_params, pivot_std, pivot_pvalues = (\n",
    "        grouped.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "        for grouped in (grouped_params, grouped_std, grouped_pvalues)\n",
    "    )\n",
    "\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "    # (grid, colormap, colorbar label, title) for the three panels of this row\n",
    "    panels = [\n",
    "        (pivot_params, 'coolwarm', beta_label, beta_label),\n",
    "        (pivot_std, 'viridis', 'Std Error', f'Std Error of {beta_label}'),\n",
    "        (pivot_pvalues, 'magma_r', 'p-value', f'p-value of {beta_label}'),\n",
    "    ]\n",
    "    for ax, (pivot, cmap, cbar_label, title) in zip(axes[i], panels):\n",
    "        sns.heatmap(pivot, annot=True, fmt=\".2f\", cmap=cmap, cbar_kws={'label': cbar_label}, ax=ax)\n",
    "        ax.set_title(title)\n",
    "        ax.set_ylabel(tau_label)\n",
    "        ax.set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s6.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Copy the specification columns onto the R-squared table\n",
    "for key in ('delta_t', 'tau'):\n",
    "    BETA_rsquared_msa_entry[key] = BETA_params_msa_entry[key]\n",
    "\n",
    "# Mean R-squared per (delta_t, tau), arranged as a tau x delta_t grid\n",
    "grouped_R2 = BETA_rsquared_msa_entry.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# R-squared heatmap on its own figure\n",
    "fig_r2, ax_r2 = plt.subplots(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'}, ax=ax_r2)\n",
    "ax_r2.set_title('Average $R^2$ of entry models')\n",
    "ax_r2.set_ylabel(tau_label)\n",
    "ax_r2.set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s7.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Mean over delta_t columns of the per-column mean R-squared\n",
    "print(pivot_R2.mean().mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6485fd6",
   "metadata": {},
   "source": [
    "### Figure S8 MSA LEVEL: Heatmaps for the dependence of the EXIT model coefficients on $\Delta t$ and $\tau$\n",
    "\n",
    "### Figure S9 MSA LEVEL: Heatmaps showing the coefficient of determination of the EXIT models as a function of $\Delta t$ and $\tau$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae0f61d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute delta_t and tau (as integers) from the year columns of each coefficient table\n",
    "for df in [BETA_params_msa_exit, BETA_std_msa_exit, BETA_pvalue_msa_exit]:\n",
    "    df['delta_t'] = (df['end_year'] - df['start_year']).astype(int)\n",
    "    df['tau'] = (df['mid_year'] - df['start_year']).astype(int)\n",
    "\n",
    "# Slope coefficients only: every 'beta_*' column except the intercept\n",
    "beta_columns = [\n",
    "    col for col in BETA_params_msa_exit.columns\n",
    "    if col.startswith('beta_') and col != 'beta_Intercept'\n",
    "]\n",
    "\n",
    "# LaTeX titles for the coefficients\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\tilde{\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# LaTeX axis labels\n",
    "delta_t_label = '$\\Delta t$'\n",
    "tau_label = '$\\tau$'\n",
    "\n",
    "# Average estimates, standard errors and p-values over all models sharing (delta_t, tau)\n",
    "group_keys = ['delta_t', 'tau']\n",
    "grouped_params = BETA_params_msa_exit.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_msa_exit.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_msa_exit.groupby(group_keys)[beta_columns].mean().reset_index()\n",
    "\n",
    "# One row of heatmaps per coefficient: estimate, standard error, p-value\n",
    "num_beta_params = len(beta_columns)\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params), squeeze=False)\n",
    "\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # tau x delta_t grids for this coefficient, with both axes sorted\n",
    "    pivot_params, pivot_std, pivot_pvalues = (\n",
    "        grouped.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "        for grouped in (grouped_params, grouped_std, grouped_pvalues)\n",
    "    )\n",
    "\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "    # (grid, colormap, colorbar label, title) for the three panels of this row\n",
    "    panels = [\n",
    "        (pivot_params, 'coolwarm', beta_label, beta_label),\n",
    "        (pivot_std, 'viridis', 'Std Error', f'Std Error of {beta_label}'),\n",
    "        (pivot_pvalues, 'magma_r', 'p-value', f'p-value of {beta_label}'),\n",
    "    ]\n",
    "    for ax, (pivot, cmap, cbar_label, title) in zip(axes[i], panels):\n",
    "        sns.heatmap(pivot, annot=True, fmt=\".2f\", cmap=cmap, cbar_kws={'label': cbar_label}, ax=ax)\n",
    "        ax.set_title(title)\n",
    "        ax.set_ylabel(tau_label)\n",
    "        ax.set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s8.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Copy the specification columns onto the R-squared table\n",
    "for key in ('delta_t', 'tau'):\n",
    "    BETA_rsquared_msa_exit[key] = BETA_params_msa_exit[key]\n",
    "\n",
    "# Mean R-squared per (delta_t, tau), arranged as a tau x delta_t grid\n",
    "grouped_R2 = BETA_rsquared_msa_exit.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# R-squared heatmap on its own figure\n",
    "fig_r2, ax_r2 = plt.subplots(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'}, ax=ax_r2)\n",
    "ax_r2.set_title('Average $R^2$ of exit models')\n",
    "ax_r2.set_ylabel(tau_label)\n",
    "ax_r2.set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s9.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Mean over delta_t columns of the per-column mean R-squared\n",
    "print(pivot_R2.mean().mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a04b9d94",
   "metadata": {},
   "source": [
    "\n",
    "### Figure S10 MSA LEVEL:  Histograms for the model coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfe9e93c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histograms of the MSA-level entry and exit model coefficients for one\n",
    "# (delta_t, tau) pair; the figure is saved as figure_s10.png.\n",
    "\n",
    "# Set the delta_t and tau values\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the fits obtained with the selected delta_t and tau\n",
    "filtered_params_entry = BETA_params_msa_entry[\n",
    "    (BETA_params_msa_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_msa_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_msa_exit[\n",
    "    (BETA_params_msa_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_msa_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# Exclude 'beta_Intercept' from beta columns\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# Coefficients present in both models, kept in the entry table's column order.\n",
    "# A set intersection is avoided on purpose: the iteration order of a set of\n",
    "# strings can change between kernel restarts (hash randomization), which would\n",
    "# shuffle the rows of the saved figure from run to run.\n",
    "beta_columns = [col for col in beta_columns_entry if col in beta_columns_exit]\n",
    "\n",
    "# If there are no common beta columns, print a message and skip the figure\n",
    "if not beta_columns:\n",
    "    print(\"No common beta coefficients found between entry and exit models.\")\n",
    "else:\n",
    "    # Mapping from beta column names to LaTeX labels\n",
    "    latex_labels = {\n",
    "        'beta_RCA_mid': '$R_{cp}(t+\\\\tau)$',\n",
    "        'beta_RCA_start': '$R_{cp}(t)$',\n",
    "        'beta_Relatedness_start': '$\\\\omega_{cp}(t)$',\n",
    "        'beta_Relative_relatedness_start': '$\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "    }\n",
    "\n",
    "    # Set up the plotting style\n",
    "    sns.set(style='whitegrid')\n",
    "\n",
    "    # Number of beta parameters\n",
    "    num_beta_params = len(beta_columns)\n",
    "\n",
    "    # One row per coefficient, one column per model; squeeze=False keeps `axes`\n",
    "    # two-dimensional even when there is a single row\n",
    "    fig, axes = plt.subplots(nrows=num_beta_params, ncols=2, figsize=(14, 4 * num_beta_params), squeeze=False)\n",
    "\n",
    "    # (title prefix, coefficient table, bar colour) for each column of panels\n",
    "    panels = [\n",
    "        ('Entry Model', filtered_params_entry, 'skyblue'),\n",
    "        ('Exit Model', filtered_params_exit, 'salmon'),\n",
    "    ]\n",
    "\n",
    "    # Plot histograms for each beta parameter\n",
    "    for i, beta_col in enumerate(beta_columns):\n",
    "        beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "        for j, (model_name, model_params, color) in enumerate(panels):\n",
    "            ax = axes[i, j]\n",
    "            if not model_params.empty:\n",
    "                sns.histplot(model_params[beta_col], kde=True, bins=10, color=color, ax=ax)\n",
    "            else:\n",
    "                # No fit matched the selected (delta_t, tau): leave a placeholder panel\n",
    "                ax.text(0.5, 0.5, 'No data available', ha='center', va='center')\n",
    "            ax.set_title(f'{model_name}: Coefficient for {beta_label}')\n",
    "            ax.set_xlabel('Coefficient Value')\n",
    "            ax.set_ylabel('Frequency')\n",
    "\n",
    "    # Adjust layout\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Save the figure as a high-quality PNG\n",
    "    plt.savefig('figure_s10.png', dpi=300, bbox_inches='tight')\n",
    "\n",
    "    # Display the figure\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "949e5cd5",
   "metadata": {},
   "source": [
    "### MSA LEVEL: Calculate coefficients used in the optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b038cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the MSA-level entry and exit coefficients over all fits obtained\n",
    "# with the selected (delta_t, tau); the results are stored in beta_msa_entry\n",
    "# and beta_msa_exit.\n",
    "\n",
    "# Set the delta_t and tau values\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the fits obtained with the selected delta_t and tau\n",
    "filtered_params_entry = BETA_params_msa_entry[\n",
    "    (BETA_params_msa_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_msa_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_msa_exit[\n",
    "    (BETA_params_msa_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_msa_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# Keep every beta_ column; 'beta_Intercept' is NOT filtered out here\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_')]\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_')]\n",
    "\n",
    "# Check if there is data to compute the average\n",
    "# (beta_msa_entry / beta_msa_exit are only assigned when matching fits exist)\n",
    "if not filtered_params_entry.empty and beta_columns_entry:\n",
    "    # Compute the mean of the beta coefficients for the entry model\n",
    "    beta_msa_entry = filtered_params_entry[beta_columns_entry].mean()\n",
    "    print(f\"Average beta coefficients for entry model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_msa_entry)\n",
    "else:\n",
    "    print(f\"No data available for entry model with delta_t = {delta_t_value} and tau = {tau_value}.\")\n",
    "\n",
    "if not filtered_params_exit.empty and beta_columns_exit:\n",
    "    # Compute the mean of the beta coefficients for the exit model\n",
    "    beta_msa_exit = filtered_params_exit[beta_columns_exit].mean()\n",
    "    print(f\"\\nAverage beta coefficients for exit model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_msa_exit)\n",
    "else:\n",
    "    print(f\"No data available for exit model with delta_t = {delta_t_value} and tau = {tau_value}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "161127e5-951b-4dbd-ae84-7f3681b530af",
   "metadata": {},
   "source": [
    "## Estimate Patents level models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2149ab8f-1e62-45e2-95f1-33fbe0e5ffc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the patent-level (PCT) entry and exit regressions for every\n",
    "# (delta_t, tau, start_year) window and store the coefficients, standard\n",
    "# errors, p-values and R2 of each fit.\n",
    "\n",
    "# Initialize lists to store the rows (one dict per fitted window)\n",
    "beta_params_entry_list = []\n",
    "beta_std_entry_list = []\n",
    "beta_pvalues_entry_list = []\n",
    "beta_rsquared_entry_list = []\n",
    "\n",
    "beta_params_exit_list = []\n",
    "beta_std_exit_list = []\n",
    "beta_pvalues_exit_list = []\n",
    "beta_rsquared_exit_list = []\n",
    "\n",
    "# Define the range of delta_t and tau\n",
    "min_delta_t = 8\n",
    "max_delta_t = 15  # Adjust as needed\n",
    "min_tau = 3\n",
    "\n",
    "# Years covered by the windows: first start year and last admissible end year\n",
    "first_start_year = 1999\n",
    "last_end_year = 2021\n",
    "\n",
    "# Every yearly file is needed by many windows, so each CSV is read only once\n",
    "pct_data_cache = {}\n",
    "\n",
    "\n",
    "def load_pct_data(year):\n",
    "    \"\"\"Return the table stored in pct_data/pct_data_<year>.csv, reading the file on first use only.\"\"\"\n",
    "    if year not in pct_data_cache:\n",
    "        pct_data_cache[year] = pd.read_csv('pct_data/pct_data_' + str(year) + '.csv')\n",
    "    return pct_data_cache[year]\n",
    "\n",
    "\n",
    "def find_indices(original, to_find):\n",
    "    \"\"\"Return the first position in `original` of each item of `to_find` that occurs in it.\"\"\"\n",
    "    first_position = {}\n",
    "    for position, item in enumerate(original):\n",
    "        first_position.setdefault(item, position)\n",
    "    return [first_position[item] for item in to_find if item in first_position]\n",
    "\n",
    "\n",
    "# Loop over possible delta_t values\n",
    "for delta_t in range(min_delta_t, max_delta_t + 1):\n",
    "    # tau stops at delta_t - 3, so mid_year is at least 3 years before end_year\n",
    "    max_tau = delta_t - 3\n",
    "    # Loop over possible tau values\n",
    "    for tau in range(min_tau, max_tau + 1):\n",
    "        # Latest start_year whose end_year (start_year + delta_t) does not exceed last_end_year\n",
    "        max_start_year = last_end_year - delta_t\n",
    "        # Loop over possible start_year values\n",
    "        for start_year in range(first_start_year, max_start_year + 1):\n",
    "            mid_year = start_year + tau\n",
    "            end_year = start_year + delta_t\n",
    "\n",
    "            # Yearly matrices, indexed by the 'Row' column of each file\n",
    "            data_start = load_pct_data(start_year)\n",
    "            X_start = data_start.set_index(\"Row\")\n",
    "            pct_start = X_start.index.tolist()\n",
    "            categories_start = X_start.columns.tolist()\n",
    "\n",
    "            data_mid = load_pct_data(mid_year)\n",
    "            X_mid = data_mid.set_index(\"Row\")\n",
    "            pct_mid = X_mid.index.tolist()\n",
    "            categories_mid = X_mid.columns.tolist()\n",
    "\n",
    "            data_end = load_pct_data(end_year)\n",
    "            X_end = data_end.set_index(\"Row\")\n",
    "            pct_end = X_end.index.tolist()\n",
    "            categories_end = X_end.columns.tolist()\n",
    "\n",
    "            # Find the intersection of pct and categories\n",
    "            pct_probit = np.intersect1d(np.intersect1d(pct_start, pct_mid), pct_end)\n",
    "            categories_probit = np.intersect1d(np.intersect1d(categories_start, categories_mid), categories_end)\n",
    "\n",
    "            # Skip iteration if no common pct or categories\n",
    "            if len(pct_probit) == 0 or len(categories_probit) == 0:\n",
    "                continue\n",
    "\n",
    "            # Positions of the common rows and columns in each yearly matrix\n",
    "            z_pct_start = find_indices(pct_start, pct_probit)\n",
    "            z_pct_mid = find_indices(pct_mid, pct_probit)\n",
    "            z_pct_end = find_indices(pct_end, pct_probit)\n",
    "\n",
    "            z_categories_start = find_indices(categories_start, categories_probit)\n",
    "            z_categories_mid = find_indices(categories_mid, categories_probit)\n",
    "            z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "            # Creating matrices for start, mid, and end years using .iloc\n",
    "            X_mat_start = X_start.iloc[z_pct_start, z_categories_start]\n",
    "            X_mat_mid = X_mid.iloc[z_pct_mid, z_categories_mid]\n",
    "            X_mat_end = X_end.iloc[z_pct_end, z_categories_end]\n",
    "\n",
    "            # Keep only rows and columns whose end-year total is positive;\n",
    "            # the start and mid matrices are filtered with the end-year masks\n",
    "            keep_rows = X_mat_end.sum(axis=1) > 0\n",
    "            keep_cols = X_mat_end.sum(axis=0) > 0\n",
    "            X_mat_start = X_mat_start[keep_rows].loc[:, keep_cols]\n",
    "            X_mat_mid = X_mat_mid[keep_rows].loc[:, keep_cols]\n",
    "            X_mat_end = X_mat_end[keep_rows]\n",
    "            # Column totals of the end matrix are taken after its rows were filtered\n",
    "            X_mat_end = X_mat_end.loc[:, X_mat_end.sum(axis=0) > 0]\n",
    "\n",
    "            pct_probit = X_mat_end.index.tolist()\n",
    "            categories_probit = X_mat_end.columns.tolist()\n",
    "\n",
    "            # RCA matrices from eciopt.rca and their binarization at RCA > 1\n",
    "            RCA_start = eciopt.rca(X_mat_start)\n",
    "            RCA_mid = eciopt.rca(X_mat_mid)\n",
    "            RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "            M_start = (RCA_start > 1).astype(float)\n",
    "            M_mid = (RCA_mid > 1).astype(float)\n",
    "            M_end = (RCA_end > 1).astype(float)\n",
    "\n",
    "            # Relatedness at the start year (third value returned by eciopt.cplex_rank)\n",
    "            try:\n",
    "                _, _, Relatedness_start, _ = eciopt.cplex_rank(M_start, pct_probit, categories_probit)\n",
    "            except Exception as e:\n",
    "                print(f\"Error in cplex_rank for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Flatten the Relatedness_start matrix\n",
    "            Relatedness_start = Relatedness_start.flatten()\n",
    "\n",
    "            # Row and column labels repeated to line up with the flattened matrices\n",
    "            pct_all = np.repeat(pct_probit, len(categories_probit))\n",
    "            products_all = np.tile(categories_probit, len(pct_probit))\n",
    "\n",
    "            RR = RCA_start.flatten()\n",
    "            RR_mid = RCA_mid.flatten()\n",
    "            RR_end = RCA_end.flatten()\n",
    "\n",
    "            # Create a DataFrame for analysis\n",
    "            probit_data = pd.DataFrame({\n",
    "                'pct_all': pct_all,\n",
    "                'products_all': products_all,\n",
    "                'M_end': M_end.flatten(),\n",
    "                'Relatedness_start': Relatedness_start,\n",
    "                'RCA_start': np.log(1 + RR),\n",
    "                'RCA_mid': np.log(1 + RR_mid),\n",
    "                'RCA_end': np.log(1 + RR_end)\n",
    "            })\n",
    "\n",
    "            # RCA_start holds log(1 + RCA), so log(2) is the RCA = 1 threshold:\n",
    "            # 'entry' rows start below it, 'exit' rows start at or above it\n",
    "            entry_data = probit_data[probit_data['RCA_start'] < np.log(2)].copy()\n",
    "            exit_data = probit_data[probit_data['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "            # Skip if one of the two subsets is empty\n",
    "            if entry_data.empty or exit_data.empty:\n",
    "                continue\n",
    "\n",
    "            # z-score of Relatedness_start within each pct_all group, computed separately for each subset\n",
    "            entry_data['Relative_relatedness_start'] = entry_data.groupby('pct_all')['Relatedness_start'].transform(zscore)\n",
    "            exit_data['Relative_relatedness_start'] = exit_data.groupby('pct_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "            # Define the model formula\n",
    "            formula = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "            # Fit the model for the 'entry' subset\n",
    "            try:\n",
    "                model_entry = ols(formula=formula, data=entry_data)\n",
    "                result_entry = model_entry.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting entry model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Fit the model for the 'exit' subset; a failure here also discards the\n",
    "            # entry fit, so the entry and exit tables keep one row per window each\n",
    "            try:\n",
    "                model_exit = ols(formula=formula, data=exit_data)\n",
    "                result_exit = model_exit.fit()\n",
    "            except Exception as e:\n",
    "                print(f\"Error fitting exit model for start_year {start_year}, delta_t {delta_t}, tau {tau}: {e}\")\n",
    "                continue\n",
    "\n",
    "            # Extract parameters, standard errors, p-values and R2\n",
    "            beta_params_entry = result_entry.params\n",
    "            beta_std_entry = result_entry.bse\n",
    "            beta_pvalues_entry = result_entry.pvalues\n",
    "            beta_rsquared_entry = result_entry.rsquared\n",
    "\n",
    "            beta_params_exit = result_exit.params\n",
    "            beta_std_exit = result_exit.bse\n",
    "            beta_pvalues_exit = result_exit.pvalues\n",
    "            beta_rsquared_exit = result_exit.rsquared\n",
    "\n",
    "            # Columns shared by every output row of this window\n",
    "            window = {'start_year': start_year, 'mid_year': mid_year, 'end_year': end_year, 'delta_t': delta_t, 'tau': tau}\n",
    "\n",
    "            # Prepare rows to append to the lists (each row gets its own copy of the window columns)\n",
    "            row_params_entry = dict(window)\n",
    "            row_std_entry = dict(window)\n",
    "            row_pvalues_entry = dict(window)\n",
    "            row_rsquared_entry = dict(window, R2=beta_rsquared_entry)\n",
    "\n",
    "            row_params_exit = dict(window)\n",
    "            row_std_exit = dict(window)\n",
    "            row_pvalues_exit = dict(window)\n",
    "            row_rsquared_exit = dict(window, R2=beta_rsquared_exit)\n",
    "\n",
    "            # Add parameters to the rows, one 'beta_<term>' column per model term\n",
    "            for param_name in beta_params_entry.index:\n",
    "                row_params_entry['beta_' + param_name] = beta_params_entry[param_name]\n",
    "                row_std_entry['beta_' + param_name] = beta_std_entry[param_name]\n",
    "                row_pvalues_entry['beta_' + param_name] = beta_pvalues_entry[param_name]\n",
    "\n",
    "            for param_name in beta_params_exit.index:\n",
    "                row_params_exit['beta_' + param_name] = beta_params_exit[param_name]\n",
    "                row_std_exit['beta_' + param_name] = beta_std_exit[param_name]\n",
    "                row_pvalues_exit['beta_' + param_name] = beta_pvalues_exit[param_name]\n",
    "\n",
    "            # Append the rows to the lists\n",
    "            beta_params_entry_list.append(row_params_entry)\n",
    "            beta_std_entry_list.append(row_std_entry)\n",
    "            beta_pvalues_entry_list.append(row_pvalues_entry)\n",
    "            beta_rsquared_entry_list.append(row_rsquared_entry)\n",
    "\n",
    "            beta_params_exit_list.append(row_params_exit)\n",
    "            beta_std_exit_list.append(row_std_exit)\n",
    "            beta_pvalues_exit_list.append(row_pvalues_exit)\n",
    "            beta_rsquared_exit_list.append(row_rsquared_exit)\n",
    "\n",
    "            # Progress report for the window that was just fitted\n",
    "            display(\"Start year:\", start_year)\n",
    "            display(\"Mid year:\", mid_year)\n",
    "            display(\"End year:\", end_year)\n",
    "\n",
    "# After the loops, create DataFrames from the lists\n",
    "BETA_params_pct_entry = pd.DataFrame(beta_params_entry_list)\n",
    "BETA_std_pct_entry = pd.DataFrame(beta_std_entry_list)\n",
    "BETA_pvalue_pct_entry = pd.DataFrame(beta_pvalues_entry_list)\n",
    "BETA_rsquared_pct_entry = pd.DataFrame(beta_rsquared_entry_list)\n",
    "\n",
    "\n",
    "BETA_params_pct_exit = pd.DataFrame(beta_params_exit_list)\n",
    "BETA_std_pct_exit = pd.DataFrame(beta_std_exit_list)\n",
    "BETA_pvalue_pct_exit = pd.DataFrame(beta_pvalues_exit_list)\n",
    "BETA_rsquared_pct_exit = pd.DataFrame(beta_rsquared_exit_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3409d7c3-d045-4a1b-8d7b-a6b725dc9a39",
   "metadata": {},
   "source": [
    "### Figure S11 Patents LEVEL: Heatmaps for the dependence of the ENTRY model coefficients on $\\Delta t$ and $\\tau$\n",
    "\n",
    "### Figure S12 Patents LEVEL: Heatmaps showing the coefficient of determination of the ENTRY models as a function of $\\Delta t$ and $\\tau$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fd6c10-6375-4642-9b8d-e96ef738e5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figures S11 and S12: patent-level ENTRY model coefficients (with their standard\n",
    "# errors and p-values) and R2, averaged over windows sharing the same delta_t and tau.\n",
    "\n",
    "# Recompute delta_t and tau from the year columns of each coefficient table\n",
    "for df in [BETA_params_pct_entry, BETA_std_pct_entry, BETA_pvalue_pct_entry]:\n",
    "    df['delta_t'] = (df['end_year'] - df['start_year']).astype(int)\n",
    "    df['tau'] = (df['mid_year'] - df['start_year']).astype(int)\n",
    "\n",
    "# Coefficient columns to plot; 'beta_Intercept' is left out\n",
    "beta_columns = [col for col in BETA_params_pct_entry.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# LaTeX titles for the coefficient columns\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# LaTeX axis labels\n",
    "delta_t_label = '$\\\\Delta t$'\n",
    "tau_label = '$\\\\tau$'\n",
    "\n",
    "# Mean coefficient, standard error and p-value for every (delta_t, tau) pair\n",
    "grouped_params = BETA_params_pct_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_pct_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_pct_entry.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "\n",
    "# One row of three heatmaps per coefficient\n",
    "num_beta_params = len(beta_columns)\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params))\n",
    "\n",
    "# With a single row, subplots returns a 1-D array of axes: make it (1, 3)\n",
    "if num_beta_params == 1:\n",
    "    axes = axes.reshape(1, -1)\n",
    "\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # tau x delta_t grids of the averaged coefficient, standard error and p-value\n",
    "    pivot_params, pivot_std, pivot_pvalues = [\n",
    "        grouped.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "        for grouped in (grouped_params, grouped_std, grouped_pvalues)\n",
    "    ]\n",
    "\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "    # (grid, colour map, colour-bar label, panel title) for the three panels of this row\n",
    "    panel_specs = [\n",
    "        (pivot_params, 'coolwarm', beta_label, beta_label),\n",
    "        (pivot_std, 'viridis', 'Std Error', f'Std Error of {beta_label}'),\n",
    "        (pivot_pvalues, 'magma_r', 'p-value', f'p-value of {beta_label}'),\n",
    "    ]\n",
    "    for j, (grid, cmap_name, cbar_label, panel_title) in enumerate(panel_specs):\n",
    "        sns.heatmap(grid, annot=True, fmt=\".2f\", cmap=cmap_name, cbar_kws={'label': cbar_label}, ax=axes[i, j])\n",
    "        axes[i, j].set_title(panel_title)\n",
    "        axes[i, j].set_ylabel(tau_label)\n",
    "        axes[i, j].set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s11.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Copy delta_t and tau from the coefficient table onto the R2 table (pandas aligns the rows by index)\n",
    "for key in ('delta_t', 'tau'):\n",
    "    BETA_rsquared_pct_entry[key] = BETA_params_pct_entry[key]\n",
    "\n",
    "# Mean R2 per (delta_t, tau), arranged as a tau x delta_t grid\n",
    "grouped_R2 = BETA_rsquared_pct_entry.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# R2 heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'})\n",
    "plt.title('Average $R^2$ of entry models')\n",
    "plt.ylabel(tau_label)\n",
    "plt.xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s12.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Mean of the per-delta_t column means of the grid\n",
    "print(pivot_R2.mean().mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2909a919-a316-46af-85b8-b95982551d52",
   "metadata": {},
   "source": [
    "### Figure S13 Patents LEVEL: Heatmaps for the dependence of the EXIT model coefficients on $\\Delta t$ and $\\tau$\n",
    "\n",
    "### Figure S14 Patents LEVEL: Heatmaps showing the coefficient of determination of the EXIT models as a function of $\\Delta t$ and $\\tau$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1585e33e-b9ed-4046-af1d-f16decffb0e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figures S13 and S14: patent-level EXIT model coefficients (with their standard\n",
    "# errors and p-values) and R2, averaged over windows sharing the same delta_t and tau.\n",
    "\n",
    "# Recompute delta_t and tau from the year columns of each coefficient table\n",
    "for df in [BETA_params_pct_exit, BETA_std_pct_exit, BETA_pvalue_pct_exit]:\n",
    "    df['delta_t'] = (df['end_year'] - df['start_year']).astype(int)\n",
    "    df['tau'] = (df['mid_year'] - df['start_year']).astype(int)\n",
    "\n",
    "# Coefficient columns to plot; 'beta_Intercept' is left out\n",
    "beta_columns = [col for col in BETA_params_pct_exit.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# LaTeX titles for the coefficient columns\n",
    "latex_labels = {\n",
    "    'beta_RCA_mid': 'Coefficient of $R_{cp}(t+\\\\tau)$',\n",
    "    'beta_RCA_start': 'Coefficient of $R_{cp}(t)$',\n",
    "    'beta_Relatedness_start': 'Coefficient of $\\\\omega_{cp}(t)$',\n",
    "    'beta_Relative_relatedness_start': 'Coefficient of $\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "}\n",
    "\n",
    "# LaTeX axis labels\n",
    "delta_t_label = '$\\\\Delta t$'\n",
    "tau_label = '$\\\\tau$'\n",
    "\n",
    "# Mean coefficient, standard error and p-value for every (delta_t, tau) pair\n",
    "grouped_params = BETA_params_pct_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_std = BETA_std_pct_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "grouped_pvalues = BETA_pvalue_pct_exit.groupby(['delta_t', 'tau'])[beta_columns].mean().reset_index()\n",
    "\n",
    "# One row of three heatmaps per coefficient\n",
    "num_beta_params = len(beta_columns)\n",
    "fig, axes = plt.subplots(nrows=num_beta_params, ncols=3, figsize=(18, 4 * num_beta_params))\n",
    "\n",
    "# With a single row, subplots returns a 1-D array of axes: make it (1, 3)\n",
    "if num_beta_params == 1:\n",
    "    axes = axes.reshape(1, -1)\n",
    "\n",
    "for i, beta_col in enumerate(beta_columns):\n",
    "    # tau x delta_t grids of the averaged coefficient, standard error and p-value\n",
    "    pivot_params, pivot_std, pivot_pvalues = [\n",
    "        grouped.pivot(index='tau', columns='delta_t', values=beta_col).sort_index().sort_index(axis=1)\n",
    "        for grouped in (grouped_params, grouped_std, grouped_pvalues)\n",
    "    ]\n",
    "\n",
    "    beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "    # (grid, colour map, colour-bar label, panel title) for the three panels of this row\n",
    "    panel_specs = [\n",
    "        (pivot_params, 'coolwarm', beta_label, beta_label),\n",
    "        (pivot_std, 'viridis', 'Std Error', f'Std Error of {beta_label}'),\n",
    "        (pivot_pvalues, 'magma_r', 'p-value', f'p-value of {beta_label}'),\n",
    "    ]\n",
    "    for j, (grid, cmap_name, cbar_label, panel_title) in enumerate(panel_specs):\n",
    "        sns.heatmap(grid, annot=True, fmt=\".2f\", cmap=cmap_name, cbar_kws={'label': cbar_label}, ax=axes[i, j])\n",
    "        axes[i, j].set_title(panel_title)\n",
    "        axes[i, j].set_ylabel(tau_label)\n",
    "        axes[i, j].set_xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s13.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Copy delta_t and tau from the coefficient table onto the R2 table (pandas aligns the rows by index)\n",
    "for key in ('delta_t', 'tau'):\n",
    "    BETA_rsquared_pct_exit[key] = BETA_params_pct_exit[key]\n",
    "\n",
    "# Mean R2 per (delta_t, tau), arranged as a tau x delta_t grid\n",
    "grouped_R2 = BETA_rsquared_pct_exit.groupby(['delta_t', 'tau'])['R2'].mean().reset_index()\n",
    "pivot_R2 = grouped_R2.pivot(index='tau', columns='delta_t', values='R2').sort_index().sort_index(axis=1)\n",
    "\n",
    "# R2 heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.heatmap(pivot_R2, annot=True, fmt=\".2f\", cmap='coolwarm', cbar_kws={'label': 'Average $R^2$'})\n",
    "plt.title('Average $R^2$ of exit models')\n",
    "plt.ylabel(tau_label)\n",
    "plt.xlabel(delta_t_label)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure_s14.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Mean of the per-delta_t column means of the grid\n",
    "print(pivot_R2.mean().mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06afa5a-59b4-4b0b-ba66-56772c078fea",
   "metadata": {},
   "source": [
    "### Figure S15 Patents LEVEL:  Histograms for the model coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c0d0d6-3085-420f-8c73-620a76fe8d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the delta_t and tau values\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the rows whose delta_t and tau match the selected values\n",
    "filtered_params_entry = BETA_params_pct_entry[\n",
    "    (BETA_params_pct_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_pct_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_pct_exit[\n",
    "    (BETA_params_pct_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_pct_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# Exclude 'beta_Intercept' from beta columns\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_') and col != 'beta_Intercept']\n",
    "\n",
    "# Coefficients present in both models, kept in the column order of the entry frame.\n",
    "# A set intersection is avoided on purpose: the iteration order of a set of strings can\n",
    "# change between Python sessions, which would reshuffle the rows of the saved figure.\n",
    "beta_columns = [col for col in beta_columns_entry if col in beta_columns_exit]\n",
    "\n",
    "# If there are no common beta columns, print a message and skip the figure\n",
    "if not beta_columns:\n",
    "    print(\"No common beta coefficients found between entry and exit models.\")\n",
    "else:\n",
    "    # Mapping from beta column names to LaTeX labels\n",
    "    latex_labels = {\n",
    "        'beta_RCA_mid': '$R_{cp}(t+\\\\tau)$',\n",
    "        'beta_RCA_start': '$R_{cp}(t)$',\n",
    "        'beta_Relatedness_start': '$\\\\omega_{cp}(t)$',\n",
    "        'beta_Relative_relatedness_start': '$\\\\tilde{\\\\omega}_{cp}(t)$'\n",
    "    }\n",
    "\n",
    "    # Set up the plotting style\n",
    "    sns.set(style='whitegrid')\n",
    "\n",
    "    # Number of beta parameters\n",
    "    num_beta_params = len(beta_columns)\n",
    "\n",
    "    # One row per coefficient: entry model on the left, exit model on the right.\n",
    "    # squeeze=False keeps `axes` two-dimensional even when there is a single row,\n",
    "    # so axes[i, 0] / axes[i, 1] below always work.\n",
    "    fig, axes = plt.subplots(nrows=num_beta_params, ncols=2, figsize=(14, 4 * num_beta_params), squeeze=False)\n",
    "\n",
    "    # Plot histograms for each beta parameter\n",
    "    for i, beta_col in enumerate(beta_columns):\n",
    "        beta_label = latex_labels.get(beta_col, beta_col)\n",
    "\n",
    "        # Entry model histogram (or a placeholder when the filter matched no rows)\n",
    "        ax_entry = axes[i, 0]\n",
    "        if not filtered_params_entry.empty:\n",
    "            sns.histplot(filtered_params_entry[beta_col], kde=True, bins=10, color='skyblue', ax=ax_entry)\n",
    "        else:\n",
    "            ax_entry.text(0.5, 0.5, 'No data available', ha='center', va='center')\n",
    "        ax_entry.set_title(f'Entry Model: Coefficient for {beta_label}')\n",
    "        ax_entry.set_xlabel('Coefficient Value')\n",
    "        ax_entry.set_ylabel('Frequency')\n",
    "\n",
    "        # Exit model histogram (or a placeholder when the filter matched no rows)\n",
    "        ax_exit = axes[i, 1]\n",
    "        if not filtered_params_exit.empty:\n",
    "            sns.histplot(filtered_params_exit[beta_col], kde=True, bins=10, color='salmon', ax=ax_exit)\n",
    "        else:\n",
    "            ax_exit.text(0.5, 0.5, 'No data available', ha='center', va='center')\n",
    "        ax_exit.set_title(f'Exit Model: Coefficient for {beta_label}')\n",
    "        ax_exit.set_xlabel('Coefficient Value')\n",
    "        ax_exit.set_ylabel('Frequency')\n",
    "\n",
    "    # Adjust layout\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Save the figure as a high-quality PNG\n",
    "    plt.savefig('figure_s15.png', dpi=300, bbox_inches='tight')\n",
    "\n",
    "    # Display the figure\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12a24fe5-8c7c-4c91-8f6a-abdbcdc9309c",
   "metadata": {},
   "source": [
    "### Patents LEVEL: Calculate coefficients used in the optimization model with delta_t = 10 and tau = 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a155cd-96ce-4cb9-91f1-a2ceef17a4ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the delta_t and tau values\n",
    "delta_t_value = 10\n",
    "tau_value = 5\n",
    "\n",
    "# Keep only the rows whose delta_t and tau match the selected values\n",
    "filtered_params_entry = BETA_params_pct_entry[\n",
    "    (BETA_params_pct_entry['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_pct_entry['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "filtered_params_exit = BETA_params_pct_exit[\n",
    "    (BETA_params_pct_exit['delta_t'] == delta_t_value) &\n",
    "    (BETA_params_pct_exit['tau'] == tau_value)\n",
    "]\n",
    "\n",
    "# Collect every column whose name starts with 'beta_'. Nothing is excluded here,\n",
    "# so 'beta_Intercept' (if present) is averaged together with the other coefficients.\n",
    "beta_columns_entry = [col for col in filtered_params_entry.columns if col.startswith('beta_')]\n",
    "beta_columns_exit = [col for col in filtered_params_exit.columns if col.startswith('beta_')]\n",
    "\n",
    "# Average the coefficients column-wise. When the filter matches no rows the\n",
    "# corresponding beta_pct_* variable is NOT assigned by this cell.\n",
    "if not filtered_params_entry.empty and beta_columns_entry:\n",
    "    # Compute the mean of the beta coefficients for the entry model\n",
    "    beta_pct_entry = filtered_params_entry[beta_columns_entry].mean()\n",
    "    print(f\"Average beta coefficients for entry model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_pct_entry)\n",
    "else:\n",
    "    print(f\"No data available for entry model with delta_t = {delta_t_value} and tau = {tau_value}.\")\n",
    "\n",
    "if not filtered_params_exit.empty and beta_columns_exit:\n",
    "    # Compute the mean of the beta coefficients for the exit model\n",
    "    beta_pct_exit = filtered_params_exit[beta_columns_exit].mean()\n",
    "    print(f\"\\nAverage beta coefficients for exit model (delta_t = {delta_t_value}, tau = {tau_value}):\")\n",
    "    print(beta_pct_exit)\n",
    "else:\n",
    "    print(f\"No data available for exit model with delta_t = {delta_t_value} and tau = {tau_value}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a26a39c",
   "metadata": {},
   "source": [
    "### TABLE 1 Entry and Exit Model for Target Year 2022"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3c83ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stargazer.stargazer import Stargazer\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# 5-year specification: start -> mid -> end\n",
    "start_year = 2012\n",
    "mid_year = 2017\n",
    "end_year = 2022\n",
    "\n",
    "# Read data for start_year\n",
    "filename = 'trade_data/bilateral_' + str(start_year) + '.csv'\n",
    "data_start = pd.read_csv(filename)\n",
    "\n",
    "# Exporter x product matrix of summed trade values.\n",
    "# aggfunc is given by name ('sum'): passing np.sum is deprecated in recent pandas\n",
    "# and emits a FutureWarning, while the string form keeps the same behavior.\n",
    "X_start = data_start.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                                 aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_start = X_start[X_start.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_start = X_start.loc[:, X_start.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_start = X_start.index.tolist()\n",
    "categories_start = X_start.columns.tolist()\n",
    "\n",
    "# Read data for mid_year\n",
    "filename = 'trade_data/bilateral_' + str(mid_year) + '.csv'\n",
    "data_mid = pd.read_csv(filename)\n",
    "\n",
    "\n",
    "X_mid = data_mid.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_mid = X_mid[X_mid.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_mid = X_mid.loc[:, X_mid.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_mid = X_mid.index.tolist()\n",
    "categories_mid = X_mid.columns.tolist()\n",
    "\n",
    "# Read data for end_year\n",
    "filename = 'trade_data/bilateral_' + str(end_year) + '.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "\n",
    "X_end = data_end.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_end = X_end[X_end.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_end = X_end.loc[:, X_end.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_end = X_end.index.tolist()\n",
    "categories_end = X_end.columns.tolist()\n",
    "\n",
    "# Keep only the countries and categories that survive the filters in all three years\n",
    "# (np.intersect1d returns them sorted and without duplicates)\n",
    "countries_probit = np.intersect1d(countries_start, countries_mid)\n",
    "countries_probit = np.intersect1d(countries_probit, countries_end)\n",
    "categories_probit = np.intersect1d(categories_start, categories_mid)\n",
    "categories_probit = np.intersect1d(categories_probit, categories_end)\n",
    "\n",
    "\n",
    "# Function to find indices\n",
    "def find_indices(original, to_find):\n",
    "    \"\"\"Return, for each item of `to_find` that occurs in `original`, its first position there.\n",
    "\n",
    "    Positions are reported in the order of `to_find`; items that do not occur in\n",
    "    `original` are skipped. Items must be hashable.\n",
    "    \"\"\"\n",
    "    # One pass records the first position of every item, instead of a linear\n",
    "    # `in` test plus a linear `.index()` scan for each looked-up item.\n",
    "    first_position = {}\n",
    "    for position, item in enumerate(original):\n",
    "        first_position.setdefault(item, position)\n",
    "    return [first_position[item] for item in to_find if item in first_position]\n",
    "\n",
    "# Finding indices\n",
    "# Positions of the shared countries/categories inside each year's own label lists.\n",
    "# find_indices walks the shared list, so all three selections come out in the same order.\n",
    "z_countries_start = find_indices(countries_start, countries_probit)\n",
    "z_countries_mid = find_indices(countries_mid, countries_probit)\n",
    "z_countries_end = find_indices(countries_end, countries_probit)\n",
    "\n",
    "z_categories_start = find_indices(categories_start, categories_probit)\n",
    "z_categories_mid = find_indices(categories_mid, categories_probit)\n",
    "z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "# Creating matrices for start, mid, and end years using .iloc\n",
    "X_mat_start = X_start.iloc[z_countries_start, z_categories_start]\n",
    "X_mat_mid = X_mid.iloc[z_countries_mid, z_categories_mid]\n",
    "X_mat_end = X_end.iloc[z_countries_end, z_categories_end]\n",
    "\n",
    "# rca and cplex_rank are provided by the ecioptimization module (imported as eciopt)\n",
    "RCA_start = eciopt.rca(X_mat_start)\n",
    "RCA_mid = eciopt.rca(X_mat_mid)\n",
    "RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "# Binary matrix: 1.0 where the start-year RCA is strictly greater than 1, else 0.0\n",
    "M_start = (RCA_start > 1).astype(float)\n",
    "\n",
    "# Only the third value returned by cplex_rank is used here\n",
    "_, _, Relatedness_start, _ = eciopt.cplex_rank(M_start, countries_probit, categories_probit)\n",
    "\n",
    "\n",
    "# Flatten the Relatedness_start matrix\n",
    "Relatedness_start = Relatedness_start.flatten()\n",
    "\n",
    "# Create repeated arrays for countries and products\n",
    "# NOTE(review): repeat/tile labelling assumes the flattened matrices are row-major with\n",
    "# countries as rows and categories as columns -- confirm against eciopt.rca / cplex_rank.\n",
    "countries_all = np.repeat(countries_probit, len(categories_probit))\n",
    "products_all = np.tile(categories_probit, len(countries_probit))\n",
    "\n",
    "RR = RCA_start.flatten()\n",
    "RR_mid = RCA_mid.flatten()\n",
    "RR_end = RCA_end.flatten()\n",
    "\n",
    "# Create a DataFrame for analysis (one row per country-product pair).\n",
    "# The RCA columns hold log(1 + RCA), not raw RCA.\n",
    "probit_data = pd.DataFrame({\n",
    "    'countries_all': countries_all,\n",
    "    'products_all': products_all,\n",
    "    'Relatedness_start': Relatedness_start,\n",
    "    'RCA_start': np.log(1 + RR),\n",
    "    'RCA_mid': np.log(1 + RR_mid),\n",
    "    'RCA_end': np.log(1 + RR_end)\n",
    "})\n",
    "\n",
    "# Split the data into 'entry' and 'exit' subsets\n",
    "# log(2) = log(1 + 1): 'entry' rows have start-year RCA < 1, 'exit' rows have RCA >= 1\n",
    "entry_data = probit_data[probit_data['RCA_start'] < np.log(2)].copy()\n",
    "exit_data = probit_data[probit_data['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "\n",
    "# Calculate the z-score of Relatedness_start separately for each subset,\n",
    "# standardising within each country (groupby on 'countries_all')\n",
    "entry_data['Relative_relatedness_start'] = entry_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "exit_data['Relative_relatedness_start'] = exit_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "# Define the model formula\n",
    "# NOTE(review): `formula` is identical to `formula_all` and is not referenced by the fits below\n",
    "formula = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "formula_RCA = 'RCA_end ~ RCA_mid + RCA_start'\n",
    "formula_Relatedness = 'RCA_end ~ Relatedness_start + Relative_relatedness_start'\n",
    "formula_all = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "# Fit the models (OLS): three specifications on the entry subset, three on the exit subset\n",
    "model_entry_RCA = ols(formula=formula_RCA, data=entry_data).fit()\n",
    "model_entry_Relatedness = ols(formula=formula_Relatedness, data=entry_data).fit()\n",
    "model_entry_all = ols(formula=formula_all, data=entry_data).fit()\n",
    "\n",
    "model_exit_RCA = ols(formula=formula_RCA, data=exit_data).fit()\n",
    "model_exit_Relatedness = ols(formula=formula_Relatedness, data=exit_data).fit()\n",
    "model_exit_all = ols(formula=formula_all, data=exit_data).fit()\n",
    "\n",
    "# Create a Stargazer object\n",
    "\n",
    "stargazer = Stargazer([model_entry_RCA, model_entry_Relatedness, model_entry_all, model_exit_RCA, model_exit_Relatedness, model_exit_all])\n",
    "\n",
    "# Covariate order and renaming\n",
    "stargazer.covariate_order([\n",
    "    'RCA_mid',                # Main effect of RCA_mid\n",
    "    'RCA_start',              # Main effect of RCA_start\n",
    "    'Relatedness_start',      # Main effect of Relatedness_start\n",
    "    'Relative_relatedness_start'  # Main effect of Relative_relatedness_start\n",
    "])\n",
    "\n",
    "\n",
    "# Row labels shown in the table carry the year each regressor refers to\n",
    "stargazer.rename_covariates({\n",
    "    f'RCA_mid': f'log of RCA ({mid_year})',\n",
    "    f'RCA_start': f'log of RCA ({start_year})',\n",
    "    f'Relatedness_start': f'Relatedness ({start_year})',\n",
    "    f'Relative_relatedness_start': f'Relative Relatedness ({start_year})'\n",
    "})\n",
    "\n",
    "stargazer.model_names = False\n",
    "# One custom column header per model, in the order the models were passed to Stargazer\n",
    "stargazer.custom_columns([\n",
    "    \"Entry (RCA)\", \n",
    "    \"Entry (Rel.)\", \n",
    "    \"Entry (All)\", \n",
    "    \"Exit (RCA)\", \n",
    "    \"Exit (Rel.)\", \n",
    "    \"Exit (All)\"\n",
    "], [1, 1, 1, 1, 1, 1])\n",
    "\n",
    "\n",
    "# Set the dependent variable label\n",
    "stargazer.dependent_variable = f'log(1 + RCA) ({end_year})'\n",
    "\n",
    "# Render the HTML\n",
    "html_output = stargazer.render_html()\n",
    "\n",
    "# Display the HTML in Jupyter Notebook\n",
    "display(HTML(html_output))\n",
    "\n",
    "# Save the HTML output to a file\n",
    "with open(\"table1-5year.html\", \"w\") as file:\n",
    "    file.write(html_output) \n",
    "    \n",
    "    \n",
    "### 10 YEAR\n",
    "\n",
    "\n",
    "# 10-year specification: start -> mid -> end\n",
    "start_year = 2002\n",
    "mid_year = 2012\n",
    "end_year = 2022\n",
    "\n",
    "# Read data for start_year\n",
    "filename = 'trade_data/bilateral_' + str(start_year) + '.csv'\n",
    "data_start = pd.read_csv(filename)\n",
    "\n",
    "# Exporter x product matrix of summed trade values.\n",
    "# aggfunc is given by name ('sum'): passing np.sum is deprecated in recent pandas\n",
    "# and emits a FutureWarning, while the string form keeps the same behavior.\n",
    "X_start = data_start.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                                 aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_start = X_start[X_start.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_start = X_start.loc[:, X_start.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_start = X_start.index.tolist()\n",
    "categories_start = X_start.columns.tolist()\n",
    "\n",
    "# Read data for mid_year\n",
    "filename = 'trade_data/bilateral_' + str(mid_year) + '.csv'\n",
    "data_mid = pd.read_csv(filename)\n",
    "\n",
    "\n",
    "X_mid = data_mid.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_mid = X_mid[X_mid.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_mid = X_mid.loc[:, X_mid.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_mid = X_mid.index.tolist()\n",
    "categories_mid = X_mid.columns.tolist()\n",
    "\n",
    "# Read data for end_year\n",
    "filename = 'trade_data/bilateral_' + str(end_year) + '.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "\n",
    "X_end = data_end.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_end = X_end[X_end.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_end = X_end.loc[:, X_end.sum(axis=0) >= 500000]\n",
    "\n",
    "countries_end = X_end.index.tolist()\n",
    "categories_end = X_end.columns.tolist()\n",
    "\n",
    "# Keep only the countries and categories that survive the filters in all three years\n",
    "# (np.intersect1d returns them sorted and without duplicates)\n",
    "countries_probit = np.intersect1d(countries_start, countries_mid)\n",
    "countries_probit = np.intersect1d(countries_probit, countries_end)\n",
    "categories_probit = np.intersect1d(categories_start, categories_mid)\n",
    "categories_probit = np.intersect1d(categories_probit, categories_end)\n",
    "\n",
    "\n",
    "# Function to find indices\n",
    "def find_indices(original, to_find):\n",
    "    \"\"\"Return, for each item of `to_find` that occurs in `original`, its first position there.\n",
    "\n",
    "    Positions are reported in the order of `to_find`; items that do not occur in\n",
    "    `original` are skipped. Items must be hashable.\n",
    "    \"\"\"\n",
    "    # One pass records the first position of every item, instead of a linear\n",
    "    # `in` test plus a linear `.index()` scan for each looked-up item.\n",
    "    first_position = {}\n",
    "    for position, item in enumerate(original):\n",
    "        first_position.setdefault(item, position)\n",
    "    return [first_position[item] for item in to_find if item in first_position]\n",
    "\n",
    "# Finding indices\n",
    "# Positions of the shared countries/categories inside each year's own label lists.\n",
    "# find_indices walks the shared list, so all three selections come out in the same order.\n",
    "z_countries_start = find_indices(countries_start, countries_probit)\n",
    "z_countries_mid = find_indices(countries_mid, countries_probit)\n",
    "z_countries_end = find_indices(countries_end, countries_probit)\n",
    "\n",
    "z_categories_start = find_indices(categories_start, categories_probit)\n",
    "z_categories_mid = find_indices(categories_mid, categories_probit)\n",
    "z_categories_end = find_indices(categories_end, categories_probit)\n",
    "\n",
    "# Creating matrices for start, mid, and end years using .iloc\n",
    "X_mat_start = X_start.iloc[z_countries_start, z_categories_start]\n",
    "X_mat_mid = X_mid.iloc[z_countries_mid, z_categories_mid]\n",
    "X_mat_end = X_end.iloc[z_countries_end, z_categories_end]\n",
    "\n",
    "# rca and cplex_rank are provided by the ecioptimization module (imported as eciopt)\n",
    "RCA_start = eciopt.rca(X_mat_start)\n",
    "RCA_mid = eciopt.rca(X_mat_mid)\n",
    "RCA_end = eciopt.rca(X_mat_end)\n",
    "\n",
    "# Binary matrix: 1.0 where the start-year RCA is strictly greater than 1, else 0.0\n",
    "M_start = (RCA_start > 1).astype(float)\n",
    "\n",
    "# Only the third value returned by cplex_rank is used here\n",
    "_, _, Relatedness_start, _ = eciopt.cplex_rank(M_start, countries_probit, categories_probit)\n",
    "\n",
    "\n",
    "# Flatten the Relatedness_start matrix\n",
    "Relatedness_start = Relatedness_start.flatten()\n",
    "\n",
    "# Create repeated arrays for countries and products\n",
    "# NOTE(review): repeat/tile labelling assumes the flattened matrices are row-major with\n",
    "# countries as rows and categories as columns -- confirm against eciopt.rca / cplex_rank.\n",
    "countries_all = np.repeat(countries_probit, len(categories_probit))\n",
    "products_all = np.tile(categories_probit, len(countries_probit))\n",
    "\n",
    "RR = RCA_start.flatten()\n",
    "RR_mid = RCA_mid.flatten()\n",
    "RR_end = RCA_end.flatten()\n",
    "\n",
    "# Create a DataFrame for analysis (one row per country-product pair).\n",
    "# The RCA columns hold log(1 + RCA), not raw RCA.\n",
    "probit_data = pd.DataFrame({\n",
    "    'countries_all': countries_all,\n",
    "    'products_all': products_all,\n",
    "    'Relatedness_start': Relatedness_start,\n",
    "    'RCA_start': np.log(1 + RR),\n",
    "    'RCA_mid': np.log(1 + RR_mid),\n",
    "    'RCA_end': np.log(1 + RR_end)\n",
    "})\n",
    "\n",
    "# Split the data into 'entry' and 'exit' subsets\n",
    "# log(2) = log(1 + 1): 'entry' rows have start-year RCA < 1, 'exit' rows have RCA >= 1\n",
    "entry_data = probit_data[probit_data['RCA_start'] < np.log(2)].copy()\n",
    "exit_data = probit_data[probit_data['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "\n",
    "# Calculate the z-score of Relatedness_start separately for each subset,\n",
    "# standardising within each country (groupby on 'countries_all')\n",
    "entry_data['Relative_relatedness_start'] = entry_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "exit_data['Relative_relatedness_start'] = exit_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "# Define the model formula\n",
    "# NOTE(review): `formula` is identical to `formula_all` and is not referenced by the fits below\n",
    "formula = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "formula_RCA = 'RCA_end ~ RCA_mid + RCA_start'\n",
    "formula_Relatedness = 'RCA_end ~ Relatedness_start + Relative_relatedness_start'\n",
    "formula_all = 'RCA_end ~ RCA_mid + RCA_start + Relatedness_start + Relative_relatedness_start'\n",
    "\n",
    "# Fit the models (OLS): three specifications on the entry subset, three on the exit subset\n",
    "model_entry_RCA = ols(formula=formula_RCA, data=entry_data).fit()\n",
    "model_entry_Relatedness = ols(formula=formula_Relatedness, data=entry_data).fit()\n",
    "model_entry_all = ols(formula=formula_all, data=entry_data).fit()\n",
    "\n",
    "model_exit_RCA = ols(formula=formula_RCA, data=exit_data).fit()\n",
    "model_exit_Relatedness = ols(formula=formula_Relatedness, data=exit_data).fit()\n",
    "model_exit_all = ols(formula=formula_all, data=exit_data).fit()\n",
    "\n",
    "# Create a Stargazer object\n",
    "\n",
    "stargazer = Stargazer([model_entry_RCA, model_entry_Relatedness, model_entry_all, model_exit_RCA, model_exit_Relatedness, model_exit_all])\n",
    "\n",
    "# Covariate order and renaming\n",
    "stargazer.covariate_order([\n",
    "    'RCA_mid',                # Main effect of RCA_mid\n",
    "    'RCA_start',              # Main effect of RCA_start\n",
    "    'Relatedness_start',      # Main effect of Relatedness_start\n",
    "    'Relative_relatedness_start'  # Main effect of Relative_relatedness_start\n",
    "])\n",
    "\n",
    "\n",
    "# Row labels shown in the table carry the year each regressor refers to\n",
    "stargazer.rename_covariates({\n",
    "    f'RCA_mid': f'log of RCA ({mid_year})',\n",
    "    f'RCA_start': f'log of RCA ({start_year})',\n",
    "    f'Relatedness_start': f'Relatedness ({start_year})',\n",
    "    f'Relative_relatedness_start': f'Relative Relatedness ({start_year})'\n",
    "})\n",
    "\n",
    "stargazer.model_names = False\n",
    "# One custom column header per model, in the order the models were passed to Stargazer\n",
    "stargazer.custom_columns([\n",
    "    \"Entry (RCA)\", \n",
    "    \"Entry (Rel.)\", \n",
    "    \"Entry (All)\", \n",
    "    \"Exit (RCA)\", \n",
    "    \"Exit (Rel.)\", \n",
    "    \"Exit (All)\"\n",
    "], [1, 1, 1, 1, 1, 1])\n",
    "\n",
    "\n",
    "# Set the dependent variable label\n",
    "stargazer.dependent_variable = f'log(1 + RCA) ({end_year})'\n",
    "\n",
    "# Render the HTML\n",
    "html_output = stargazer.render_html()\n",
    "\n",
    "# Display the HTML in Jupyter Notebook\n",
    "display(HTML(html_output))\n",
    "\n",
    "# Save the HTML output to a file\n",
    "with open(\"table1-10year.html\", \"w\") as file:\n",
    "    file.write(html_output) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c924942",
   "metadata": {},
   "source": [
    "# ADDITIONAL CALCULATIONS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9592a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using colors from a nice palette (tab10)\n",
    "color_pred = [0, 0.4470, 0.7410]  # Blue\n",
    "color_baseline = [0.9290, 0.6940, 0.1250]   # Yellow\n",
    "\n",
    "# Define markers for each dataset\n",
    "marker_baseline = '^'    # Upward triangle for Relatedness\n",
    "marker_pred = 's'   # Square for ECI Optimization\n",
    "\n",
    "\n",
    "end_year = 2022\n",
    "\n",
    "\n",
    "## NATIONAL DATA\n",
    "filename = 'trade_data/bilateral_' + str(end_year) +'.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "# Exporter x product matrix of summed trade values.\n",
    "# aggfunc is given by name ('sum'): passing np.sum is deprecated in recent pandas\n",
    "# and emits a FutureWarning, while the string form keeps the same behavior.\n",
    "X_end = data_end.pivot_table(values='value', index='exporter_id', columns='hs_code',\n",
    "                             aggfunc='sum', fill_value=0)\n",
    "\n",
    "# Remove rows with sum less than 10^9\n",
    "X_end = X_end[X_end.sum(axis=1) >= 10**9]\n",
    "\n",
    "# Remove columns with sum less than 500,000\n",
    "X_end = X_end.loc[:, X_end.sum(axis=0) >= 500000]\n",
    "\n",
    "# Row / column labels of the filtered matrix, in matrix order\n",
    "countries_probit = X_end.index.tolist()\n",
    "products_probit = X_end.columns.tolist()\n",
    "\n",
    "# Function to find indices\n",
    "def find_indices(original, to_find):\n",
    "    \"\"\"Return, for each item of `to_find` that occurs in `original`, its first position there.\n",
    "\n",
    "    Positions are reported in the order of `to_find`; items that do not occur in\n",
    "    `original` are skipped. Items must be hashable.\n",
    "    \"\"\"\n",
    "    # One pass records the first position of every item, instead of a linear\n",
    "    # `in` test plus a linear `.index()` scan for each looked-up item.\n",
    "    first_position = {}\n",
    "    for position, item in enumerate(original):\n",
    "        first_position.setdefault(item, position)\n",
    "    return [first_position[item] for item in to_find if item in first_position]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "X_countries = X_end.copy().to_numpy()\n",
    "\n",
    "\n",
    "# Assuming rca and cplex_rank functions are already defined in Python\n",
    "RCA_countries = eciopt.rca(X_countries)\n",
    "\n",
    "M_end = (RCA_countries > 1).astype(float)\n",
    "\n",
    "# Flatten and process other matrices\n",
    "CountryRankings_2022, ProductRankings_2022, Relatedness_countries_2022, _ = eciopt.cplex_rank(M_end, countries_probit, products_probit)\n",
    "\n",
    "pci_2022 = ProductRankings_2022['PCI'].values\n",
    "\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_end @ pci_2022.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_end.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "CountryRankings_2022['ECI_not_normalized'] = normalized_product.flatten()\n",
    "\n",
    "\n",
    "# Create repeated arrays for countries and products\n",
    "countries_all = np.repeat(countries_probit, len(products_probit))\n",
    "products_all = np.tile(products_probit, len(countries_probit))\n",
    "\n",
    "\n",
    "data_optimization_countries = pd.DataFrame({\n",
    "    'countries_all': countries_all,\n",
    "    'products_all': products_all,\n",
    "    'Relatedness_start': Relatedness_countries_2022.flatten(),\n",
    "    'RCA_start': np.log(1+RCA_countries.flatten()),\n",
    "    'RCA_mid': np.log(1+RCA_countries.flatten()),\n",
    "    'RCA_end': np.log(1+RCA_countries.flatten())\n",
    "})\n",
    "\n",
    "optimal_threshold = 1\n",
    "\n",
    "# Split the data into two subsets based on the RCA_start condition and label them as 'entry' and 'exit'\n",
    "entry_data = data_optimization_countries[data_optimization_countries['RCA_start'] < np.log(2)].copy()\n",
    "exit_data = data_optimization_countries[data_optimization_countries['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "# Calculate the z-score of Relatedness_start separately for each subset\n",
    "entry_data['Relative_relatedness_start'] = entry_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "exit_data['Relative_relatedness_start'] = exit_data.groupby('countries_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "entry_data['predicted_prob'] = np.expm1(beta_country_entry['beta_Intercept'] + beta_country_entry['beta_RCA_mid']*entry_data['RCA_mid'] + beta_country_entry['beta_RCA_start']*entry_data['RCA_start'] + beta_country_entry['beta_Relatedness_start']*entry_data['Relatedness_start'] \n",
    "+ beta_country_entry['beta_Relative_relatedness_start']*entry_data['Relative_relatedness_start']) \n",
    "\n",
    "exit_data['predicted_prob'] = np.expm1(beta_country_exit['beta_Intercept'] + beta_country_exit['beta_RCA_mid']*exit_data['RCA_mid'] + beta_country_exit['beta_RCA_start']*exit_data['RCA_start'] + beta_country_exit['beta_Relatedness_start']*exit_data['Relatedness_start'] \n",
    "+ beta_country_exit['beta_Relative_relatedness_start']*exit_data['Relative_relatedness_start']) \n",
    "\n",
    "# Place the predicted probabilities back into the original probit_data\n",
    "data_optimization_countries.loc[data_optimization_countries['RCA_start'] < np.log(2), 'predicted_prob'] = entry_data['predicted_prob']\n",
    "data_optimization_countries.loc[data_optimization_countries['RCA_start'] >= np.log(2), 'predicted_prob'] = exit_data['predicted_prob']\n",
    "\n",
    "data_optimization_countries.loc[data_optimization_countries['RCA_start'] < np.log(2), 'Relative_relatedness_start'] = entry_data['Relative_relatedness_start']\n",
    "data_optimization_countries.loc[data_optimization_countries['RCA_start'] >= np.log(2), 'Relative_relatedness_start'] = exit_data['Relative_relatedness_start']\n",
    "\n",
    "# Pivot the DataFrame to create the matrix\n",
    "predicted_country_matrix = data_optimization_countries.pivot(index='countries_all', \n",
    "                                        columns='products_all', \n",
    "                                        values='predicted_prob')\n",
    "\n",
    "# Reindex the pivoted DataFrame to match the original order of countries and products\n",
    "predicted_country_matrix = predicted_country_matrix.reindex(index=countries_probit, columns=products_probit)\n",
    "\n",
    "M_countries = np.array(predicted_country_matrix > optimal_threshold, dtype=float)\n",
    "\n",
    "## similarity matrix\n",
    "\n",
    "# Ubiquity and diversity\n",
    "Kp0 = M_countries.sum(axis=0)\n",
    "Kc0 = M_countries.sum(axis=1)\n",
    "\n",
    "# Calculate proximity of products (PHIpp):\n",
    "#   PHIpp[i, j] = dot(M[:, i], M[:, j]) / max(Kp0[i], Kp0[j])\n",
    "# M.T @ M yields every pairwise column dot product in one call and np.maximum.outer\n",
    "# builds the matrix of pairwise maxima, replacing the Python double loop over products.\n",
    "# M_countries only holds 0.0 / 1.0, so the values are identical to the loop version;\n",
    "# as before, a pair of all-zero columns produces NaN (0 / 0).\n",
    "PHIpp_country = (M_countries.T @ M_countries) / np.maximum.outer(Kp0, Kp0)\n",
    "\n",
    "probit_country_data = data_optimization_countries.copy()\n",
    "\n",
    "CountryRankings, ProductRankings, _, _ = eciopt.cplex_rank(M_countries, countries_probit, products_probit)\n",
    "\n",
    "CountryRankings['ECI_2022'] = CountryRankings_2022['ECI'].copy()\n",
    "CountryRankings['ECI_not_normalized_2022'] = CountryRankings_2022['ECI_not_normalized'].copy()\n",
    "\n",
    "\n",
    "ProductRankings['PCI_2022'] = ProductRankings_2022['PCI'].copy()\n",
    "pci = ProductRankings['PCI'].values\n",
    "\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_countries @ pci.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_countries.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "CountryRankings['ECI_not_normalized'] = normalized_product.flatten()\n",
    "sd_for_countries = np.std(normalized_product)\n",
    "mean_for_countries = np.mean(normalized_product)\n",
    "\n",
    "probit_countries_data = data_optimization_countries.copy()\n",
    "\n",
    "## SUBNATIONAL DATA\n",
    "\n",
    "filename = 'msa_data/msa_data_' + str(end_year) +'.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "# MSA x NAICS matrix of summed 'ap_avg' values.\n",
    "# aggfunc is given by name ('sum'): passing np.sum is deprecated in recent pandas\n",
    "# and emits a FutureWarning, while the string form keeps the same behavior.\n",
    "X_end = data_end.pivot_table(values='ap_avg', index='msa', columns='naics',\n",
    "                                      aggfunc='sum', fill_value=0)\n",
    "\n",
    "\n",
    "# Keep rows (MSAs) whose total is strictly greater than 10^5\n",
    "X_end = X_end[X_end.sum(axis=1) > 10**5]\n",
    "# Keep columns (NAICS codes) whose total is strictly greater than 10^5.5 (about 316,228)\n",
    "X_end = X_end.loc[:, X_end.sum(axis=0) > 10**5.5]\n",
    "\n",
    "\n",
    "# Row / column labels of the filtered matrix, in matrix order\n",
    "msa_probit = X_end.index.tolist()\n",
    "activities_probit = X_end.columns.tolist()\n",
    "\n",
    "X_msa = X_end.copy().to_numpy() \n",
    "\n",
    "\n",
    "RCA_msa = eciopt.rca(X_msa)\n",
    "\n",
    "M_end = (RCA_msa > 1).astype(float)\n",
    "\n",
    "# Flatten and process other matrices\n",
    "msaRankings_2022, NaicsRankings_2022, Relatedness_msa_2022, _ = eciopt.cplex_rank(M_end, msa_probit, activities_probit)\n",
    "\n",
    "pci_2022 = NaicsRankings_2022['PCI'].values\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_end @ pci_2022.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_end.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "msaRankings_2022['ECI_not_normalized'] = normalized_product.flatten()\n",
    "\n",
    "\n",
    "# Create repeated arrays for countries and products\n",
    "msa_all = np.repeat(msa_probit, len(activities_probit))\n",
    "activities_all = np.tile(activities_probit, len(msa_probit))\n",
    "\n",
    "\n",
    "data_optimization_msa = pd.DataFrame({\n",
    "    'msa_all': msa_all,\n",
    "    'products_all': activities_all,\n",
    "    'Relatedness_start': Relatedness_msa_2022.flatten(),\n",
    "    'RCA_start': np.log(1+RCA_msa.flatten()),\n",
    "    'RCA_mid': np.log(1+RCA_msa.flatten()),\n",
    "    'RCA_end': np.log(1+RCA_msa.flatten())\n",
    "})\n",
    "\n",
    "# Split the data into two subsets based on the RCA_start condition and label them as 'entry' and 'exit'\n",
    "entry_data = data_optimization_msa[data_optimization_msa['RCA_start'] < np.log(2)].copy()\n",
    "exit_data = data_optimization_msa[data_optimization_msa['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "# Calculate the z-score of Relatedness_start separately for each subset\n",
    "entry_data['Relative_relatedness_start'] = entry_data.groupby('msa_all')['Relatedness_start'].transform(zscore)\n",
    "exit_data['Relative_relatedness_start'] = exit_data.groupby('msa_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "entry_data['predicted_prob'] = np.expm1(beta_msa_entry['beta_Intercept'] + beta_msa_entry['beta_RCA_mid']*entry_data['RCA_mid'] + beta_msa_entry['beta_RCA_start']*entry_data['RCA_start'] + beta_msa_entry['beta_Relatedness_start']*entry_data['Relatedness_start'] \n",
    "+ beta_msa_entry['beta_Relative_relatedness_start']*entry_data['Relative_relatedness_start']) \n",
    "\n",
    "exit_data['predicted_prob'] = np.expm1(beta_msa_exit['beta_Intercept'] + beta_msa_exit['beta_RCA_mid']*exit_data['RCA_mid'] + beta_msa_exit['beta_RCA_start']*exit_data['RCA_start'] + beta_msa_exit['beta_Relatedness_start']*exit_data['Relatedness_start'] \n",
    "+ beta_msa_exit['beta_Relative_relatedness_start']*exit_data['Relative_relatedness_start']) \n",
    "\n",
    "# Place the predicted probabilities back into the original probit_data\n",
    "data_optimization_msa.loc[data_optimization_msa['RCA_start'] < np.log(2), 'predicted_prob'] = entry_data['predicted_prob']\n",
    "data_optimization_msa.loc[data_optimization_msa['RCA_start'] >= np.log(2), 'predicted_prob'] = exit_data['predicted_prob']\n",
    "\n",
    "data_optimization_msa.loc[data_optimization_msa['RCA_start'] < np.log(2), 'Relative_relatedness_start'] = entry_data['Relative_relatedness_start']\n",
    "data_optimization_msa.loc[data_optimization_msa['RCA_start'] >= np.log(2), 'Relative_relatedness_start'] = exit_data['Relative_relatedness_start']\n",
    "\n",
    "# Pivot the DataFrame to create the matrix\n",
    "predicted_msa_matrix = data_optimization_msa.pivot(index='msa_all', \n",
    "                                        columns='products_all', \n",
    "                                        values='predicted_prob')\n",
    "\n",
    "# Reindex the pivoted DataFrame to match the original order of msa and products\n",
    "predicted_msa_matrix = predicted_msa_matrix.reindex(index=msa_probit, columns=activities_probit)\n",
    "\n",
    "M_msa = np.array(predicted_msa_matrix > optimal_threshold, dtype=float)\n",
    "\n",
    "## similarity matrix\n",
    "\n",
    "# Ubiquity and diversity\n",
    "Kp0 = M_msa.sum(axis=0)\n",
    "Kc0 = M_msa.sum(axis=1)\n",
    "\n",
    "# Calculate proximity of products (PHIpp)\n",
    "PHIpp_msa = np.zeros((len(Kp0), len(Kp0)))\n",
    "for i in range(len(Kp0)):\n",
    "    for j in range(len(Kp0)):\n",
    "        PHIpp_msa[i, j] = np.dot(M_msa[:, i], M_msa[:, j]) / max(Kp0[i], Kp0[j])\n",
    "\n",
    "probit_msa_data = data_optimization_msa.copy()\n",
    "\n",
    "msaRankings, NaicsRankings, _, _ = eciopt.cplex_rank(M_msa, msa_probit, activities_probit)\n",
    "\n",
    "msaRankings['ECI_2022'] = msaRankings_2022['ECI'].copy()\n",
    "msaRankings['ECI_not_normalized_2022'] = msaRankings_2022['ECI_not_normalized'].copy()\n",
    "\n",
    "NaicsRankings['PCI_2022'] = NaicsRankings['PCI'].copy()\n",
    "pci = NaicsRankings['PCI'].values\n",
    "\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_msa @ pci.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_msa.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "msaRankings['ECI_not_normalized'] = normalized_product.flatten()\n",
    "sd_for_msa = np.std(normalized_product)\n",
    "mean_for_msa = np.mean(normalized_product)\n",
    "\n",
    "probit_msa_data = data_optimization_msa.copy()\n",
    "\n",
    "\n",
    "end_year = 2021\n",
    "\n",
    "## PATENTS DATA\n",
    "optimal_threshold = 1\n",
    "\n",
    "filename = 'pct_data/pct_data_' + str(end_year) +'.csv'\n",
    "data_end = pd.read_csv(filename)\n",
    "\n",
    "X_end = data_end.set_index(\"Row\")\n",
    "\n",
    "pct_probit = X_end.index.tolist()\n",
    "patents_probit = X_end.columns.tolist()\n",
    "\n",
    "X_pct = X_end.copy().to_numpy() \n",
    "\n",
    "\n",
    "RCA_pct = eciopt.rca(X_pct)\n",
    "\n",
    "M_end = (RCA_pct > 1).astype(float)\n",
    "\n",
    "# Flatten and process other matrices\n",
    "pctRankings_2022, PCTRankings_2022, Relatedness_pct_2022, _ = eciopt.cplex_rank(M_end, pct_probit, patents_probit)\n",
    "\n",
    "pci_2022 = PCTRankings_2022['PCI'].values\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_end @ pci_2022.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_end.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "pctRankings_2022['ECI_not_normalized'] = normalized_product.flatten()\n",
    "\n",
    "\n",
    "# Create repeated arrays for countries and products\n",
    "pct_all = np.repeat(pct_probit, len(patents_probit))\n",
    "patents_all = np.tile(patents_probit, len(pct_probit))\n",
    "\n",
    "\n",
    "data_optimization_pct = pd.DataFrame({\n",
    "    'pct_all': pct_all,\n",
    "    'products_all': patents_all,\n",
    "    'Relatedness_start': Relatedness_pct_2022.flatten(),\n",
    "    'RCA_start': np.log(1+RCA_pct.flatten()),\n",
    "    'RCA_mid': np.log(1+RCA_pct.flatten()),\n",
    "    'RCA_end': np.log(1+RCA_pct.flatten())\n",
    "})\n",
    "\n",
    "# Split the data into two subsets based on the RCA_start condition and label them as 'entry' and 'exit'\n",
    "entry_data = data_optimization_pct[data_optimization_pct['RCA_start'] < np.log(2)].copy()\n",
    "exit_data = data_optimization_pct[data_optimization_pct['RCA_start'] >= np.log(2)].copy()\n",
    "\n",
    "# Calculate the z-score of Relatedness_start separately for each subset\n",
    "entry_data['Relative_relatedness_start'] = entry_data.groupby('pct_all')['Relatedness_start'].transform(zscore)\n",
    "exit_data['Relative_relatedness_start'] = exit_data.groupby('pct_all')['Relatedness_start'].transform(zscore)\n",
    "\n",
    "entry_data['predicted_prob'] = np.expm1(beta_pct_entry['beta_Intercept'] + beta_pct_entry['beta_RCA_mid']*entry_data['RCA_mid'] + beta_pct_entry['beta_RCA_start']*entry_data['RCA_start'] + beta_pct_entry['beta_Relatedness_start']*entry_data['Relatedness_start'] \n",
    "+ beta_pct_entry['beta_Relative_relatedness_start']*entry_data['Relative_relatedness_start']) \n",
    "\n",
    "exit_data['predicted_prob'] = np.expm1(beta_pct_exit['beta_Intercept'] + beta_pct_exit['beta_RCA_mid']*exit_data['RCA_mid'] + beta_pct_exit['beta_RCA_start']*exit_data['RCA_start'] + beta_pct_exit['beta_Relatedness_start']*exit_data['Relatedness_start'] \n",
    "+ beta_pct_exit['beta_Relative_relatedness_start']*exit_data['Relative_relatedness_start']) \n",
    "\n",
    "# Place the predicted probabilities back into the original probit_data\n",
    "data_optimization_pct.loc[data_optimization_pct['RCA_start'] < np.log(2), 'predicted_prob'] = entry_data['predicted_prob']\n",
    "data_optimization_pct.loc[data_optimization_pct['RCA_start'] >= np.log(2), 'predicted_prob'] = exit_data['predicted_prob']\n",
    "\n",
    "data_optimization_pct.loc[data_optimization_pct['RCA_start'] < np.log(2), 'Relative_relatedness_start'] = entry_data['Relative_relatedness_start']\n",
    "data_optimization_pct.loc[data_optimization_pct['RCA_start'] >= np.log(2), 'Relative_relatedness_start'] = exit_data['Relative_relatedness_start']\n",
    "\n",
    "# Pivot the DataFrame to create the matrix\n",
    "predicted_pct_matrix = data_optimization_pct.pivot(index='pct_all', \n",
    "                                        columns='products_all', \n",
    "                                        values='predicted_prob')\n",
    "\n",
    "# Reindex the pivoted DataFrame to match the original order of pct and products\n",
    "predicted_pct_matrix = predicted_pct_matrix.reindex(index=pct_probit, columns=patents_probit)\n",
    "\n",
    "M_pct = np.array(predicted_pct_matrix > optimal_threshold, dtype=float)\n",
    "\n",
    "## similarity matrix\n",
    "\n",
    "# Ubiquity and diversity\n",
    "Kp0 = M_pct.sum(axis=0)\n",
    "Kc0 = M_pct.sum(axis=1)\n",
    "\n",
    "# Calculate proximity of products (PHIpp)\n",
    "PHIpp_pct = np.zeros((len(Kp0), len(Kp0)))\n",
    "for i in range(len(Kp0)):\n",
    "    for j in range(len(Kp0)):\n",
    "        PHIpp_pct[i, j] = np.dot(M_pct[:, i], M_pct[:, j]) / max(Kp0[i], Kp0[j])\n",
    "\n",
    "probit_pct_data = data_optimization_pct.copy()\n",
    "\n",
    "pctRankings, PCTRankings, _, _ = eciopt.cplex_rank(M_pct, pct_probit, patents_probit)\n",
    "\n",
    "pctRankings['ECI_2022'] = pctRankings_2022['ECI'].copy()\n",
    "pctRankings['ECI_not_normalized_2022'] = pctRankings_2022['ECI_not_normalized'].copy()\n",
    "\n",
    "PCTRankings['PCI_2022'] = PCTRankings['PCI'].copy()\n",
    "pci = PCTRankings['PCI'].values\n",
    "\n",
    "\n",
    "# Step 1: Matrix-vector multiplication\n",
    "# Ensure M is a numpy array and pci is appropriately shaped for the multiplication\n",
    "product_vals = M_pct @ pci.reshape(-1, 1)  # pci reshaped to a column vector if not already\n",
    "\n",
    "# Step 2: Calculate row-wise sum of M\n",
    "row_sums = M_pct.sum(axis=1).reshape(-1, 1)  # Reshape for compatibility in division\n",
    "\n",
    "# Step 3: Element-wise division\n",
    "normalized_product = product_vals / row_sums\n",
    "# Since normalized_product might be a 2D array with a single column, flatten it if you're assigning back to a DataFrame column\n",
    "pctRankings['ECI_not_normalized'] = normalized_product.flatten()\n",
    "sd_for_pct = np.std(normalized_product)\n",
    "mean_for_pct = np.mean(normalized_product)\n",
    "\n",
    "probit_pct_data = data_optimization_pct.copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38347740-cb31-49be-b40d-300a140c4645",
   "metadata": {},
   "source": [
    "## CALCULATE THRESHOLDS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf452c0-1433-4dfa-9c5f-233ab41702e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Set ECI bins\n",
    "eci_range_values = np.arange(-2.5, 2.5, 0.2)\n",
    "\n",
    "# Map ProductIndex to PCI_2022\n",
    "product_pci = dict(zip(ProductRankings['Product'], ProductRankings['PCI_2022']))\n",
    "product_index_to_pci = [product_pci.get(p, np.nan) for p in products_probit]\n",
    "\n",
    "# Store data for table and plot\n",
    "eci_bin_data = []\n",
    "\n",
    "for eci_min in eci_range_values:\n",
    "    eci_max = eci_min + 0.2\n",
    "    eci_mid = (eci_min + eci_max) / 2\n",
    "\n",
    "    # Get countries in this ECI bin\n",
    "    selected_mask = (CountryRankings['ECI'] >= eci_min) & (CountryRankings['ECI'] < eci_max)\n",
    "    selected_countries = CountryRankings[selected_mask]['Country'].tolist()\n",
    "    selected_indices = [i for i, c in enumerate(countries_probit) if c in selected_countries]\n",
    "\n",
    "    if len(selected_indices) == 0:\n",
    "        continue\n",
    "\n",
    "    # Matrix slice: rows = countries, columns = products\n",
    "    M_sub = M_countries[selected_indices, :]\n",
    "    product_selected_mask = M_sub.sum(axis=0) > 0\n",
    "    pci_values = np.array(product_index_to_pci)[product_selected_mask]\n",
    "    pci_values = pci_values[~np.isnan(pci_values)]\n",
    "\n",
    "    if len(pci_values) < 4:\n",
    "        continue\n",
    "\n",
    "    # Calculate boxplot upper whisker\n",
    "    q1 = np.percentile(pci_values, 25)\n",
    "    q3 = np.percentile(pci_values, 75)\n",
    "    iqr = q3 - q1\n",
    "    upper_whisker = q3 + 1 * iqr\n",
    "\n",
    "    # Append to list\n",
    "    eci_bin_data.append({\n",
    "        'ECI_Lower': round(eci_min, 2),\n",
    "        'ECI_Upper': round(eci_max, 2),\n",
    "        'ECI_Mid': round(eci_mid, 2),\n",
    "        'PCI_UpperWhisker': round(upper_whisker, 3)\n",
    "    })\n",
    "\n",
    "# Convert to DataFrame\n",
    "eci_pci_whisker_df = pd.DataFrame(eci_bin_data)\n",
    "\n",
    "# ---- Plotting ----\n",
    "X = np.array(eci_pci_whisker_df['ECI_Mid']).reshape(-1, 1)\n",
    "y = np.array(eci_pci_whisker_df['PCI_UpperWhisker'])\n",
    "\n",
    "# Fit linear regression\n",
    "reg = LinearRegression()\n",
    "reg.fit(X, y)\n",
    "y_pred = reg.predict(X)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(eci_pci_whisker_df['ECI_Mid'], y, marker='o', label='Upper Whisker of PCI')\n",
    "plt.plot(eci_pci_whisker_df['ECI_Mid'], y_pred, linestyle='--', color='red',\n",
    "         label=f'Least Squares Fit\\n(y = {reg.coef_[0]:.3f}x + {reg.intercept_:.3f})')\n",
    "\n",
    "plt.xlabel('ECI Range Midpoint')\n",
    "plt.ylabel('Max PCI of Specializations (Boxplot Whisker)')\n",
    "plt.title('Upper Bound of Specialized Product Complexity by ECI Group')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Set ECI bins\n",
    "eci_range_values = np.arange(-2.5, 2.5, 0.2)\n",
    "\n",
    "# Map ProductIndex to PCI_2022\n",
    "product_pci = dict(zip(NaicsRankings['Product'], NaicsRankings['PCI_2022']))\n",
    "product_index_to_pci = [product_pci.get(p, np.nan) for p in activities_probit]\n",
    "\n",
    "# Store data for table and plot\n",
    "eci_bin_data = []\n",
    "\n",
    "for eci_min in eci_range_values:\n",
    "    eci_max = eci_min + 0.2\n",
    "    eci_mid = (eci_min + eci_max) / 2\n",
    "\n",
    "    # Get countries in this ECI bin\n",
    "    selected_mask = (msaRankings['ECI'] >= eci_min) & (msaRankings['ECI'] < eci_max)\n",
    "    selected_countries = msaRankings[selected_mask]['Country'].tolist()\n",
    "    selected_indices = [i for i, c in enumerate(msa_probit) if c in selected_countries]\n",
    "\n",
    "    if len(selected_indices) == 0:\n",
    "        continue\n",
    "\n",
    "    # Matrix slice: rows = countries, columns = products\n",
    "    M_sub = M_msa[selected_indices, :]\n",
    "    product_selected_mask = M_sub.sum(axis=0) > 0\n",
    "    pci_values = np.array(product_index_to_pci)[product_selected_mask]\n",
    "    pci_values = pci_values[~np.isnan(pci_values)]\n",
    "\n",
    "    if len(pci_values) < 4:\n",
    "        continue\n",
    "\n",
    "    # Calculate boxplot upper whisker\n",
    "    q1 = np.percentile(pci_values, 25)\n",
    "    q3 = np.percentile(pci_values, 75)\n",
    "    iqr = q3 - q1\n",
    "    upper_whisker = q3 + 1.5 * iqr\n",
    "\n",
    "    # Append to list\n",
    "    eci_bin_data.append({\n",
    "        'ECI_Lower': round(eci_min, 2),\n",
    "        'ECI_Upper': round(eci_max, 2),\n",
    "        'ECI_Mid': round(eci_mid, 2),\n",
    "        'PCI_UpperWhisker': round(upper_whisker, 3)\n",
    "    })\n",
    "\n",
    "# Convert to DataFrame\n",
    "msa_eci_pci_whisker_df = pd.DataFrame(eci_bin_data)\n",
    "\n",
    "# ---- Plotting ----\n",
    "X = np.array(msa_eci_pci_whisker_df['ECI_Mid']).reshape(-1, 1)\n",
    "y = np.array(msa_eci_pci_whisker_df['PCI_UpperWhisker'])\n",
    "\n",
    "# Fit linear regression\n",
    "reg = LinearRegression()\n",
    "reg.fit(X, y)\n",
    "y_pred = reg.predict(X)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(msa_eci_pci_whisker_df['ECI_Mid'], y, marker='o', label='Upper Whisker of PCI')\n",
    "plt.plot(msa_eci_pci_whisker_df['ECI_Mid'], y_pred, linestyle='--', color='red',\n",
    "         label=f'Least Squares Fit\\n(y = {reg.coef_[0]:.3f}x + {reg.intercept_:.3f})')\n",
    "\n",
    "plt.xlabel('ECI Range Midpoint')\n",
    "plt.ylabel('Max PCI of Specializations (Boxplot Whisker)')\n",
    "plt.title('Upper Bound of Specialized Activity Complexity by ECI Group')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Patents\n",
    "# Set ECI bins\n",
    "eci_range_values = np.arange(-2.5, 2.5, 0.2)\n",
    "\n",
    "# Map ProductIndex to PCI_2022\n",
    "product_pci = dict(zip(PCTRankings['Product'], PCTRankings['PCI_2022']))\n",
    "product_index_to_pci = [product_pci.get(p, np.nan) for p in patents_probit]\n",
    "\n",
    "# Store data for table and plot\n",
    "eci_bin_data = []\n",
    "\n",
    "for eci_min in eci_range_values:\n",
    "    eci_max = eci_min + 0.2\n",
    "    eci_mid = (eci_min + eci_max) / 2\n",
    "\n",
    "    # Get countries in this ECI bin\n",
    "    selected_mask = (pctRankings['ECI'] >= eci_min) & (pctRankings['ECI'] < eci_max)\n",
    "    selected_countries = pctRankings[selected_mask]['Country'].tolist()\n",
    "    selected_indices = [i for i, c in enumerate(pct_probit) if c in selected_countries]\n",
    "\n",
    "    if len(selected_indices) == 0:\n",
    "        continue\n",
    "\n",
    "    # Matrix slice: rows = countries, columns = products\n",
    "    M_sub = M_pct[selected_indices, :]\n",
    "    product_selected_mask = M_sub.sum(axis=0) > 0\n",
    "    pci_values = np.array(product_index_to_pci)[product_selected_mask]\n",
    "    pci_values = pci_values[~np.isnan(pci_values)]\n",
    "\n",
    "    if len(pci_values) < 4:\n",
    "        continue\n",
    "\n",
    "    # Calculate boxplot upper whisker\n",
    "    q1 = np.percentile(pci_values, 25)\n",
    "    q3 = np.percentile(pci_values, 75)\n",
    "    iqr = q3 - q1\n",
    "    upper_whisker = q3 + 1 * iqr\n",
    "\n",
    "    # Append to list\n",
    "    eci_bin_data.append({\n",
    "        'ECI_Lower': round(eci_min, 2),\n",
    "        'ECI_Upper': round(eci_max, 2),\n",
    "        'ECI_Mid': round(eci_mid, 2),\n",
    "        'PCI_UpperWhisker': round(upper_whisker, 3)\n",
    "    })\n",
    "\n",
    "# Convert to DataFrame\n",
    "pct_eci_pci_whisker_df = pd.DataFrame(eci_bin_data)\n",
    "\n",
    "# ---- Plotting ----\n",
    "X = np.array(pct_eci_pci_whisker_df['ECI_Mid']).reshape(-1, 1)\n",
    "y = np.array(pct_eci_pci_whisker_df['PCI_UpperWhisker'])\n",
    "\n",
    "# Fit linear regression\n",
    "reg = LinearRegression()\n",
    "reg.fit(X, y)\n",
    "y_pred = reg.predict(X)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(pct_eci_pci_whisker_df['ECI_Mid'], y, marker='o', label='Upper Whisker of PCI')\n",
    "plt.plot(pct_eci_pci_whisker_df['ECI_Mid'], y_pred, linestyle='--', color='red',\n",
    "         label=f'Least Squares Fit\\n(y = {reg.coef_[0]:.3f}x + {reg.intercept_:.3f})')\n",
    "\n",
    "plt.xlabel('ECI Range Midpoint')\n",
    "plt.ylabel('Max PCI of Specializations (Boxplot Whisker)')\n",
    "plt.title('Upper Bound of Specialized Patents Complexity by ECI Group')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7727a2fb",
   "metadata": {},
   "source": [
    "# FIGURE 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f78f7bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import seaborn as sns\n",
    "from scipy.spatial import ConvexHull  # <-- added\n",
    "\n",
    "\n",
    "target_country = 'vnm'\n",
    "\n",
    "country_titles = {\n",
    "    'vnm': \"Vietnam\",\n",
    "}\n",
    "\n",
    "\n",
    "num_sizes = 5\n",
    "\n",
    "\n",
    "\n",
    "k = np.where(probit_country_data['countries_all'] == target_country)[0]\n",
    "\n",
    "# Check if target_country is a number (int or float)\n",
    "if isinstance(target_country, (int, float)):\n",
    "    X_start = X_countries[target_country - 1, :].T \n",
    "    M_start = M_countries[target_country - 1, :].T \n",
    "else:\n",
    "    country_index = np.where(CountryRankings['Country'] == target_country)[0]\n",
    "    X_start = X_countries[country_index[0], :].T\n",
    "    M_start = M_countries[country_index[0], :].T\n",
    "    \n",
    "X_c_start = np.sum(X_start)\n",
    "X_p_start = np.sum(X_countries, axis=0).T\n",
    "W_p = X_p_start / np.sum(X_p_start)\n",
    "\n",
    "Relatedness_start = probit_country_data['Relatedness_start']\n",
    "Relatedness_start = Relatedness_start[k]\n",
    "Relatedness_start = Relatedness_start.to_numpy()\n",
    "\n",
    "Relative_relatedness_start = probit_country_data['Relative_relatedness_start']\n",
    "Relative_relatedness_start = Relative_relatedness_start[k]\n",
    "Relative_relatedness_start = Relative_relatedness_start.to_numpy()\n",
    "\n",
    "predicted_prob = probit_country_data['predicted_prob']\n",
    "predicted_prob = predicted_prob[k]\n",
    "predicted_prob = predicted_prob.to_numpy()\n",
    "\n",
    "RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "ProductRankings['X_start'] = X_start\n",
    "ProductRankings['Relatedness_start'] = Relatedness_start\n",
    "ProductRankings['Relative_relatedness_start'] = Relative_relatedness_start  \n",
    "ProductRankings['predicted_prob'] = predicted_prob      \n",
    "ProductRankings['X_p_start'] = X_p_start\n",
    "ProductRankings['W_p'] = W_p\n",
    "ProductRankings['RCA_start'] = RCA_start\n",
    "ProductRankings['M_start'] = M_start\n",
    "\n",
    "ProductRankings['relative_relatedness'] = np.where(ProductRankings['RCA_start'] >= 1, np.nan, ProductRankings['Relatedness_start'])\n",
    "\n",
    "# Calculate the z-score for 'Relatedness_start', ignoring NaN values\n",
    "ProductRankings['relative_relatedness'] = zscore(ProductRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "ProductRankings['world_exp'] = X_p_start\n",
    "ProductRankings['percentile'] = pd.qcut(ProductRankings['world_exp'], 5, labels=False) + 1\n",
    "\n",
    "RCA_start_entry = RCA_start[RCA_start < 1]\n",
    "Relatedness_start_entry = Relatedness_start[RCA_start < 1]\n",
    "Relative_relatedness_start_entry = Relative_relatedness_start[RCA_start < 1]\n",
    "\n",
    "RCA_start_exit = RCA_start[RCA_start >= 1]\n",
    "Relatedness_start_exit = Relatedness_start[RCA_start >= 1]\n",
    "Relative_relatedness_start_exit = Relative_relatedness_start[RCA_start >= 1]\n",
    "\n",
    "# Perform the calculations for Ycp_entry, Ycp_exit, and Ycp here\n",
    "Ycp_entry = np.exp((np.log(2) - (beta_country_entry[0] + beta_country_entry[2] * np.log(1 + RCA_start_entry) + beta_country_entry[3] * Relatedness_start_entry + beta_country_entry[4] * Relative_relatedness_start_entry)) / beta_country_entry[1]) - RCA_start_entry - 1\n",
    "Ycp_exit = np.exp((np.log(2) - (beta_country_exit[0] + beta_country_exit[2] * np.log(1 + RCA_start_exit) + beta_country_exit[3] * Relatedness_start_exit + beta_country_exit[4] * Relative_relatedness_start_exit)) / beta_country_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "Ycp = np.full(RCA_start.shape, np.nan)\n",
    "Ycp[RCA_start < 1] = Ycp_entry\n",
    "Ycp[RCA_start >= 1] = Ycp_exit\n",
    "\n",
    "ProductRankings['Ycp'] = Ycp\n",
    "\n",
    "indices_to_exclude = []\n",
    "\n",
    "ECI_initial = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI_not_normalized'].values[0]\n",
    "ECI_initial_norm = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI'].values[0]\n",
    "\n",
    "# Find the bin where the ECI falls\n",
    "matching_row = eci_pci_whisker_df[\n",
    "    (eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "    (eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "]\n",
    "\n",
    "# Extract the threshold\n",
    "if not matching_row.empty:\n",
    "    tresh = matching_row['PCI_UpperWhisker'].values[0]\n",
    "else:\n",
    "    tresh = np.nan  # fallback if no bin matches\n",
    "    \n",
    "# Example 1\n",
    "max_ECI_target = 0.05*sd_for_countries + ECI_initial\n",
    "df_eci = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country, tresh)\n",
    "ProductRankings['ECI_opt_1'] = df_eci['Added_vol'].copy()\n",
    "\n",
    "# Example 2\n",
    "max_ECI_target = 0.1*sd_for_countries + ECI_initial\n",
    "df_eci = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country, tresh)\n",
    "ProductRankings['ECI_opt_2'] = df_eci['Added_vol'].copy()\n",
    "\n",
    "# Example 3\n",
    "max_ECI_target = 0.2*sd_for_countries + ECI_initial\n",
    "df_eci = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country, tresh)\n",
    "ProductRankings['ECI_opt_3'] = df_eci['Added_vol'].copy()\n",
    "\n",
    "# --------------------------------------------------------------------------------\n",
    "# Helper function to plot each suggestion\n",
    "def plot_suggestion(ax, data, eci_opt_column, sd_multiplier, label):\n",
    "    \"\"\"\n",
    "    ax: Matplotlib Axes to plot on\n",
    "    data: DataFrame with columns: 'RCA', 'Ycp', 'PCI', 'percentile', eci_opt_column\n",
    "    eci_opt_column: name of column indicating which products are selected\n",
    "    sd_multiplier: multiplier for sd_for_countries to get your target ECI\n",
    "    label: subplot label, e.g. 'a', 'b', 'c'\n",
    "    \"\"\"\n",
    "    # Compute target ECI\n",
    "    ECI_initial = CountryRankings.loc[\n",
    "        CountryRankings['Country'] == target_country, 'ECI_not_normalized'\n",
    "    ].values[0]\n",
    "    ECI_initial = sd_multiplier * sd_for_countries + ECI_initial\n",
    "    ECI_val = (ECI_initial - mean_for_countries) / sd_for_countries\n",
    "\n",
    "    # Filter: only products with RCA < 1 (potential new entries)\n",
    "    cleared_data = data[data['RCA_start'] < 1].copy()\n",
    "\n",
    "    # Split into normal and highlight groups\n",
    "    normal_data = cleared_data[cleared_data[eci_opt_column] == 0].copy()\n",
    "    highlight_data = cleared_data[cleared_data[eci_opt_column] > 0].copy()\n",
    "\n",
    "    # Assign highlight colors\n",
    "    normal_data['highlight_color'] = [[200/255, 200/255, 200/255]] * len(normal_data)  # gray\n",
    "    highlight_data['highlight_color'] = [[0, 0.4470, 0.7410]] * len(highlight_data)   # blue\n",
    "\n",
    "    # Scatter for normal_data\n",
    "    ax.scatter(\n",
    "        normal_data['Ycp'],\n",
    "        normal_data['PCI'],\n",
    "        s=normal_data['percentile'] * 20,\n",
    "        c=normal_data['highlight_color'],\n",
    "        alpha=0.2\n",
    "    )\n",
    "\n",
    "    # Scatter for highlight_data\n",
    "    ax.scatter(\n",
    "        highlight_data['Ycp'],\n",
    "        highlight_data['PCI'],\n",
    "        s=highlight_data['percentile'] * 20,\n",
    "        c=highlight_data['highlight_color'],\n",
    "        edgecolor='black',\n",
    "        linewidth=0.5\n",
    "    )\n",
    "\n",
    "    # Draw a semi-transparent polygon around the \"highlight\" points via a Convex Hull\n",
    "    if len(highlight_data) > 2:\n",
    "        points = highlight_data[['Ycp','PCI']].dropna().values\n",
    "        hull = ConvexHull(points)\n",
    "        \n",
    "        # Get hull vertices in order, then repeat the first vertex to close the polygon\n",
    "        hull_vertices = np.append(hull.vertices, hull.vertices[0])\n",
    "        \n",
    "        ax.fill(\n",
    "            points[hull_vertices, 0],\n",
    "            points[hull_vertices, 1],\n",
    "            facecolor=[0, 0.4470, 0.7410, 0.15],  # fill: lighter blue with alpha\n",
    "            edgecolor=[0, 0.247, 0.541],         # outline: a darker shade of blue\n",
    "            linewidth=2\n",
    "        )\n",
    "\n",
    "    # Horizontal line for target ECI\n",
    "    ax.axhline(y=ECI_initial, color='black', linestyle='--', linewidth=2, alpha=0.4)\n",
    "\n",
    "    # Subplot labeling\n",
    "    ax.text(0, 1.15, label, transform=ax.transAxes, \n",
    "            fontsize=24, fontweight='normal', va='top', ha='left')\n",
    "    \n",
    "    # Title\n",
    "    country_name = country_titles.get(target_country, str(target_country))\n",
    "    ax.set_title(f\"{country_name} (Target ECI = {ECI_val:.2f})\", fontsize=16)\n",
    "\n",
    "    # Axis labels\n",
    "    ax.set_xlabel('Effort', fontsize=14)\n",
    "    ax.set_ylabel('Estimated PCI', fontsize=14)\n",
    "\n",
    "\n",
    "# Now set up figure with 3 subplots in one row\n",
    "fig = plt.figure(figsize=(20, 5))\n",
    "gs_top = fig.add_gridspec(nrows=1, ncols=3, wspace=0.3)\n",
    "ax1 = fig.add_subplot(gs_top[0, 0])\n",
    "ax2 = fig.add_subplot(gs_top[0, 1])\n",
    "ax3 = fig.add_subplot(gs_top[0, 2])\n",
    "\n",
    "plot_suggestion(ax1, ProductRankings, 'ECI_opt_1', 0.05, 'a')\n",
    "plot_suggestion(ax2, ProductRankings, 'ECI_opt_2', 0.10, 'b')\n",
    "plot_suggestion(ax3, ProductRankings, 'ECI_opt_3', 0.20, 'c')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('figure3.png', dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44e1d18d",
   "metadata": {},
   "source": [
    "## FIGURE 4 - national level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b654aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from mpl_toolkits.axes_grid1.inset_locator import InsetPosition\n",
    "\n",
    "\n",
    "filtered_countries = CountryRankings.copy()\n",
    "\n",
    "# Initialize lists to store results\n",
    "final_RCA_columns = []\n",
    "final_RCA_columns_baseline = []\n",
    "final_RCA_columns_wp = []\n",
    "initial_RCA_starts = []\n",
    "initial_Relatedness_starts = []\n",
    "\n",
    "final_added_exports = []\n",
    "final_added_exports_baseline = []\n",
    "\n",
    "added_exports_sum = []\n",
    "added_exports_count = []\n",
    "avg_new_connections = []\n",
    "avg_std_connections = []\n",
    "avg_new_rca = []\n",
    "avg_new_Ycp = []\n",
    "rca_start_count = []\n",
    "new_rca_count = []\n",
    "\n",
    "added_exports_wp_sum = []\n",
    "added_exports_wp_count = []\n",
    "avg_new_connections_wp = []\n",
    "new_rca_wp_count = []\n",
    "\n",
    "added_exports_baseline_sum = []\n",
    "added_exports_baseline_count = []\n",
    "avg_new_connections_baseline = []\n",
    "avg_std_connections_baseline = []\n",
    "avg_new_rca_baseline = []\n",
    "avg_new_Ycp_baseline = []\n",
    "new_rca_baseline_count = []\n",
    "\n",
    "acc_final = []\n",
    "acc_final_baseline = []\n",
    "\n",
    "product_connections = PHIpp_country.sum(axis=1)\n",
    "\n",
    "total_countries = len(filtered_countries)\n",
    "\n",
    "\n",
    "# Per-country pass: rebuild the product-level starting state, run the ECI optimization and the\n",
    "# relatedness-times-PCI benchmark, then append one value per summary metric to the lists above.\n",
    "for idx, target_country in enumerate(filtered_countries['Country'], start=1):\n",
    "    # Calculate the percentage completed\n",
    "    percentage_completed = (idx / total_countries) * 100\n",
    "\n",
    "    # Print the target country and the percentage completed\n",
    "    print(f\"Processing country: {target_country} ({percentage_completed:.2f}% completed)\")\n",
    "\n",
    "    # Positions of this country's rows in the long-format probit table\n",
    "    k = np.where(probit_country_data['countries_all'] == target_country)[0]\n",
    "\n",
    "    # Check if target_country is a number (int or float)\n",
    "    if isinstance(target_country, (int, float)):\n",
    "        # Numeric labels are used as 1-based row numbers of the country matrices\n",
    "        X_start = X_countries[target_country - 1, :].T\n",
    "        M_start = M_countries[target_country - 1, :].T\n",
    "    else:\n",
    "        # Otherwise locate the label in CountryRankings and take the first match\n",
    "        country_index = np.where(CountryRankings['Country'] == target_country)[0]\n",
    "        X_start = X_countries[country_index[0], :].T\n",
    "        M_start = M_countries[country_index[0], :].T\n",
    "\n",
    "    X_c_start = np.sum(X_start)                # sum of the country's row of X_countries\n",
    "    X_p_start = np.sum(X_countries, axis=0).T  # column sums of X_countries (one per product)\n",
    "    W_p = X_p_start / np.sum(X_p_start)        # each product's share of the grand total\n",
    "\n",
    "    Relatedness_start = probit_country_data['Relatedness_start']\n",
    "    Relatedness_start = Relatedness_start[k]\n",
    "    Relatedness_start = Relatedness_start.to_numpy()\n",
    "\n",
    "    Relative_relatedness_start = probit_country_data['Relative_relatedness_start']\n",
    "    Relative_relatedness_start = Relative_relatedness_start[k]\n",
    "    Relative_relatedness_start = Relative_relatedness_start.to_numpy()\n",
    "\n",
    "    predicted_prob = probit_country_data['predicted_prob']\n",
    "    predicted_prob = predicted_prob[k]\n",
    "    predicted_prob = predicted_prob.to_numpy()\n",
    "\n",
    "    # Starting RCA: the country's product share divided by the product's overall share\n",
    "    RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "    # ProductRankings is overwritten in place with this country's vectors on every iteration\n",
    "    ProductRankings['X_start'] = X_start\n",
    "    ProductRankings['Relatedness_start'] = Relatedness_start\n",
    "    ProductRankings['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    ProductRankings['predicted_prob'] = predicted_prob\n",
    "    ProductRankings['M_start'] = M_start\n",
    "    ProductRankings['X_p_start'] = X_p_start\n",
    "    ProductRankings['W_p'] = W_p\n",
    "    ProductRankings['RCA_start'] = RCA_start\n",
    "\n",
    "    # Relatedness of the products with RCA_start < 1 only (the others are set to NaN) ...\n",
    "    ProductRankings['relative_relatedness'] = np.where(ProductRankings['RCA_start'] >= 1, np.nan, ProductRankings['Relatedness_start'])\n",
    "\n",
    "    # ... standardized to z-scores, ignoring the NaN entries\n",
    "    ProductRankings['relative_relatedness'] = zscore(ProductRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "    indices_to_exclude = []\n",
    "\n",
    "    ECI_initial = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI_not_normalized'].values[0]\n",
    "\n",
    "    ECI_initial_norm = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI'].values[0]\n",
    "\n",
    "    # Find the bin [ECI_Lower, ECI_Upper) where the normalized ECI falls\n",
    "    matching_row = eci_pci_whisker_df[\n",
    "        (eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "        (eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "    ]\n",
    "\n",
    "    # Extract the threshold (the bin's upper PCI whisker)\n",
    "    if not matching_row.empty:\n",
    "        tresh = matching_row['PCI_UpperWhisker'].values[0]\n",
    "    else:\n",
    "        tresh = np.nan  # fallback if no bin matches\n",
    "\n",
    "    # Target: raise the non-normalized ECI by 0.1\n",
    "    max_ECI_target = ECI_initial + 0.1\n",
    "\n",
    "    # Shape of the per-product vectors (not used again in the visible part of this cell)\n",
    "    vector_shape = RCA_start.shape\n",
    "\n",
    "    df_eci_country = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country,tresh)\n",
    "\n",
    "    # Benchmark score per product: min-max scaled relatedness times min-max scaled PCI\n",
    "    Relatedness_start_baseline = (Relatedness_start - Relatedness_start.min()) / (Relatedness_start.max() - Relatedness_start.min())\n",
    "    pci = ProductRankings['PCI'].values\n",
    "    pci_baseline = (pci - pci.min()) / (pci.max() - pci.min())\n",
    "    baseline_vals = Relatedness_start_baseline * pci_baseline\n",
    "\n",
    "    df_eci_baseline = eciopt.find_products_criteria(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, baseline_vals)\n",
    "\n",
    "    # Split the products at RCA_start = 1: 'entry' coefficients below, 'exit' coefficients at or above\n",
    "    RCA_start_entry = RCA_start[RCA_start < 1]\n",
    "    Relatedness_start_entry = Relatedness_start[RCA_start < 1]\n",
    "    Relative_relatedness_start_entry = Relative_relatedness_start[RCA_start < 1]\n",
    "\n",
    "    RCA_start_exit = RCA_start[RCA_start >= 1]\n",
    "    Relatedness_start_exit = Relatedness_start[RCA_start >= 1]\n",
    "    Relative_relatedness_start_exit = Relative_relatedness_start[RCA_start >= 1]\n",
    "\n",
    "    # Algebraically, Y solves b0 + b1*log(1 + RCA + Y) + b2*log(1 + RCA) + b3*rel + b4*rel_rel = log(2).\n",
    "    # NOTE(review): presumably the RCA increase required by the fitted entry/exit model - confirm.\n",
    "    Ycp_entry = np.exp((np.log(2)-(beta_country_entry[0] + beta_country_entry[2] * np.log(1+RCA_start_entry) + beta_country_entry[3] * Relatedness_start_entry + beta_country_entry[4] * Relative_relatedness_start_entry))/beta_country_entry[1]) - RCA_start_entry - 1\n",
    "    Ycp_exit = np.exp((np.log(2)-(beta_country_exit[0] + beta_country_exit[2] * np.log(1+RCA_start_exit) + beta_country_exit[3] * Relatedness_start_exit + beta_country_exit[4] * Relative_relatedness_start_exit))/beta_country_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "    # Put the two halves back into one vector in the original product order\n",
    "    Ycp = np.full(RCA_start.shape, np.nan)\n",
    "    Ycp[RCA_start < 1] = Ycp_entry\n",
    "    Ycp[RCA_start >= 1] = Ycp_exit\n",
    "\n",
    "    # Keep the per-product outputs of both strategies and the starting vectors\n",
    "    final_added_exports.append(df_eci_country['Added_vol'])\n",
    "    final_added_exports_baseline.append(df_eci_baseline['Added_vol'])\n",
    "    final_RCA_columns.append(df_eci_country['RCA_final'])\n",
    "    final_RCA_columns_baseline.append(df_eci_baseline['RCA_final'])\n",
    "    initial_RCA_starts.append(RCA_start)\n",
    "    initial_Relatedness_starts.append(Relatedness_start)\n",
    "\n",
    "    # Calculate the number of RCA_start > 1\n",
    "    num_rca_start_gt_1 = np.sum(RCA_start > 1)\n",
    "    rca_start_count.append(num_rca_start_gt_1)\n",
    "\n",
    "    # ---- ECI optimization: products with final RCA > 1 that started at RCA <= 1 ----\n",
    "    final_RCA = df_eci_country['RCA_final'].to_numpy()\n",
    "    new_rca = np.sum((final_RCA > 1) & (RCA_start <= 1))\n",
    "    new_rca_count.append(new_rca)\n",
    "    added_exports_sum.append(np.sum(df_eci_country['Added_vol']))\n",
    "    added_exports_count.append(np.sum(df_eci_country['Added_vol']>0))\n",
    "    mask = (final_RCA > 1) & (RCA_start <= 1)\n",
    "\n",
    "    # Mean z-scored relatedness of the new products (NaN/inf when there are none)\n",
    "    mean_relative_relatedness = np.sum(mask * ProductRankings['relative_relatedness']) / new_rca\n",
    "\n",
    "    # Standard deviation over the same products\n",
    "    std_relative_relatedness = np.std(ProductRankings['relative_relatedness'][mask])\n",
    "\n",
    "    avg_new_connections.append(mean_relative_relatedness)\n",
    "    avg_std_connections.append(std_relative_relatedness)\n",
    "    avg_new_rca.append(np.sum(((final_RCA > 1) & (RCA_start <= 1)) * ProductRankings['RCA_start'])/new_rca)\n",
    "    avg_new_Ycp.append(np.sum((df_eci_country['Added_vol']> 0) * Ycp))\n",
    "\n",
    "    # ---- Benchmark: same statistics for the benchmark's own selection ----\n",
    "    final_RCA_baseline = df_eci_baseline['RCA_final'].to_numpy()\n",
    "    new_rca_baseline = np.sum((final_RCA_baseline > 1) & (RCA_start <= 1))\n",
    "    new_rca_baseline_count.append(new_rca_baseline)\n",
    "    added_exports_baseline_sum.append(np.sum(df_eci_baseline['Added_vol']))\n",
    "    added_exports_baseline_count.append(np.sum(df_eci_baseline['Added_vol']>0))\n",
    "\n",
    "    mask = (final_RCA_baseline > 1) & (RCA_start <= 1)\n",
    "\n",
    "    # The mean is taken over the benchmark's new products, so the denominator must be\n",
    "    # new_rca_baseline (the size of this mask), not the optimization's new_rca.\n",
    "    mean_relative_relatedness = np.sum(mask * ProductRankings['relative_relatedness']) / new_rca_baseline\n",
    "\n",
    "    # Standard deviation over the same products\n",
    "    std_relative_relatedness = np.std(ProductRankings['relative_relatedness'][mask])\n",
    "\n",
    "    avg_new_connections_baseline.append(mean_relative_relatedness)\n",
    "    avg_std_connections_baseline.append(std_relative_relatedness)\n",
    "\n",
    "    avg_new_rca_baseline.append(np.sum(((final_RCA_baseline > 1) & (RCA_start <= 1)) * ProductRankings['RCA_start'])/new_rca_baseline)\n",
    "    avg_new_Ycp_baseline.append(np.sum(((df_eci_baseline['Added_vol']> 0)) * Ycp))\n",
    "\n",
    "  \n",
    "        \n",
    "\n",
    "# Start again from a fresh copy of the rankings table and attach one column per collected metric\n",
    "filtered_countries = CountryRankings.copy()\n",
    "\n",
    "# Column name -> per-country list filled by the loop above (dict order = column order)\n",
    "country_result_columns = {\n",
    "    # ECI optimization\n",
    "    'Mc_pred': rca_start_count,\n",
    "    'delta_Mc_pred': new_rca_count,\n",
    "    'delta_opt_pred': added_exports_count,\n",
    "    'added_exp_pred': added_exports_sum,\n",
    "    'avg_new_connections_pred': avg_new_connections,\n",
    "    'avg_std_connections_pred': avg_std_connections,\n",
    "    'avg_new_rca_pred': avg_new_rca,\n",
    "    'avg_Ycp_pred': avg_new_Ycp,\n",
    "    # Benchmark\n",
    "    'delta_Mc_baseline': new_rca_baseline_count,\n",
    "    'delta_opt_baseline': added_exports_baseline_count,\n",
    "    'added_exp_baseline': added_exports_baseline_sum,\n",
    "    'avg_new_connections_baseline': avg_new_connections_baseline,\n",
    "    'avg_std_connections_baseline': avg_std_connections_baseline,\n",
    "    'avg_new_rca_baseline': avg_new_rca_baseline,\n",
    "    'avg_Ycp_baseline': avg_new_Ycp_baseline,\n",
    "}\n",
    "for column_name, column_values in country_result_columns.items():\n",
    "    # Store a copy so the DataFrame does not share the list with the accumulator\n",
    "    filtered_countries[column_name] = column_values.copy()\n",
    "\n",
    "\n",
    "#############################\n",
    "# 1) HELPER FUNCTIONS\n",
    "#############################\n",
    "def linear_fit(x, y):\n",
    "    \"\"\"\n",
    "    Fit a straight line to (x, y) with scikit-learn's LinearRegression.\n",
    "\n",
    "    x is turned into a single-feature column before fitting.\n",
    "    Returns (fitted model, predictions at x in the order given).\n",
    "    \"\"\"\n",
    "    features = np.array(x).reshape(-1, 1)\n",
    "    regressor = LinearRegression().fit(features, y)\n",
    "    return regressor, regressor.predict(features)\n",
    "\n",
    "def poly_fit(x, y, degree=2):\n",
    "    \"\"\"\n",
    "    Fits a polynomial of given degree. Returns (poly_function, predictions_for_unsorted_x).\n",
    "    poly_function is a np.poly1d object, which you can call on any x-array.\n",
    "    \"\"\"\n",
    "    coeffs = np.polyfit(x, y, deg=degree)\n",
    "    p = np.poly1d(coeffs)\n",
    "    y_pred = p(x)  # predicted y on the *original* (unsorted) x\n",
    "    return p, y_pred\n",
    "\n",
    "#############################\n",
    "# 2) SCATTER + FIT LINES\n",
    "#############################\n",
    "\n",
    "# Colors and markers (RGB triples in the 0-1 range)\n",
    "color_pred = [0, 0.4470, 0.7410]      # Blue\n",
    "color_baseline = [0.8500, 0.3250, 0.0980]   # Orange\n",
    "marker_baseline = 's'\n",
    "marker_pred = 'o'\n",
    "\n",
    "# 2 x 2 grid of panels: (a) and (b) on the top row, (c) and (d) on the bottom row\n",
    "fig, axes = plt.subplots(2, 2, figsize=(10, 8))\n",
    "\n",
    "# Shared scatter styling for the benchmark and the ECI-optimization series\n",
    "scatter_kwargs_baseline = dict(marker=marker_baseline, color=color_baseline,\n",
    "                               s=30, edgecolor='black', linewidth=0.7, alpha=0.4)\n",
    "scatter_kwargs_pred = dict(marker=marker_pred, color=color_pred,\n",
    "                           s=40, edgecolor='black', linewidth=0.7, alpha=0.5)\n",
    "\n",
    "########################################\n",
    "# Plot (a): ECI vs avg_new_rca\n",
    "########################################\n",
    "ax_a = axes[0, 0]\n",
    "x_a = filtered_countries['ECI_2022'].values\n",
    "y_a_baseline = filtered_countries['avg_new_rca_baseline'].values\n",
    "y_a_pred = filtered_countries['avg_new_rca_pred'].values\n",
    "\n",
    "# Points: benchmark first, ECI optimization on top\n",
    "ax_a.scatter(x_a, y_a_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "ax_a.scatter(x_a, y_a_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_a_bl_fit_unsorted = poly_fit(x_a, y_a_baseline, degree=2)\n",
    "p_pred, y_a_pred_fit_unsorted = poly_fit(x_a, y_a_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so each curve is drawn left to right\n",
    "x_a_sorted = np.sort(x_a)\n",
    "y_a_bl_fit_sorted = p_bl(x_a_sorted)\n",
    "y_a_pred_fit_sorted = p_pred(x_a_sorted)\n",
    "\n",
    "# Trend lines: dashed for the benchmark, solid for the optimization\n",
    "ax_a.plot(x_a_sorted, y_a_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "ax_a.plot(x_a_sorted, y_a_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "ax_a.set_title('a', fontsize=16, loc='left')\n",
    "ax_a.set(xlabel='ECI', ylabel='Average Current RCA\\nof New Products')\n",
    "\n",
    "########################################\n",
    "# Plot (b): ECI vs avg_new_connections\n",
    "########################################\n",
    "ax_b = axes[0, 1]\n",
    "x_b = filtered_countries['ECI_2022'].values\n",
    "y_b_baseline = filtered_countries['avg_new_connections_baseline'].values\n",
    "y_b_pred = filtered_countries['avg_new_connections_pred'].values\n",
    "\n",
    "# Points: benchmark first, ECI optimization on top\n",
    "ax_b.scatter(x_b, y_b_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "ax_b.scatter(x_b, y_b_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so each curve is drawn left to right\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted = p_bl(x_b_sorted)\n",
    "y_b_pred_fit_sorted = p_pred(x_b_sorted)\n",
    "\n",
    "# Trend lines: dashed for the benchmark, solid for the optimization\n",
    "ax_b.plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "ax_b.plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "ax_b.set_title('b', fontsize=16, loc='left')\n",
    "ax_b.set(xlabel='ECI', ylabel='Average Relative\\nRelatedness of New Activities')\n",
    "\n",
    "########################################\n",
    "# Plot (c): Mc_pred vs delta_Mc\n",
    "########################################\n",
    "ax_c = axes[1, 0]\n",
    "x_c = filtered_countries['Mc_pred'].values\n",
    "y_c_baseline = filtered_countries['delta_Mc_baseline'].values\n",
    "y_c_pred = filtered_countries['delta_Mc_pred'].values\n",
    "\n",
    "# Points: benchmark first, ECI optimization on top\n",
    "ax_c.scatter(x_c, y_c_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "ax_c.scatter(x_c, y_c_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Straight-line trend per series\n",
    "model_bl_c, y_c_bl_fit_unsorted = linear_fit(x_c, y_c_baseline)\n",
    "model_pred_c, y_c_pred_fit_unsorted = linear_fit(x_c, y_c_pred)\n",
    "\n",
    "# Predict on ascending x (as a single-feature column) so each line is drawn left to right\n",
    "x_c_sorted = np.sort(x_c)\n",
    "y_c_bl_fit_sorted = model_bl_c.predict(x_c_sorted.reshape(-1, 1))\n",
    "y_c_pred_fit_sorted = model_pred_c.predict(x_c_sorted.reshape(-1, 1))\n",
    "\n",
    "# Trend lines: dashed for the benchmark, solid for the optimization\n",
    "ax_c.plot(x_c_sorted, y_c_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "ax_c.plot(x_c_sorted, y_c_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "ax_c.set_title('c', fontsize=16, loc='left')\n",
    "ax_c.set(xlabel='Diversity', ylabel='Number of New Activities')\n",
    "\n",
    "########################################\n",
    "# Plot (d): Mc_pred vs added_exp (log scale)\n",
    "########################################\n",
    "ax_d = axes[1, 1]\n",
    "x_d = filtered_countries['Mc_pred'].values\n",
    "# The trends are fitted on log(1 + volume); points and lines are mapped back with exp(.) - 1\n",
    "y_d_baseline = np.log(1+filtered_countries['added_exp_baseline'].values)\n",
    "y_d_pred = np.log(1+filtered_countries['added_exp_pred'].values)\n",
    "\n",
    "# Points: benchmark first, ECI optimization on top\n",
    "ax_d.scatter(x_d, np.exp(y_d_baseline)-1, label='Benchmark', **scatter_kwargs_baseline)\n",
    "ax_d.scatter(x_d, np.exp(y_d_pred)-1, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Degree-1 fits in log space\n",
    "p_bl_d, y_d_bl_fit_unsorted = poly_fit(x_d, y_d_baseline, degree=1)\n",
    "p_pred_d, y_d_pred_fit_unsorted = poly_fit(x_d, y_d_pred, degree=1)\n",
    "\n",
    "# Evaluate on ascending x and transform back to volume units\n",
    "x_d_sorted = np.sort(x_d)\n",
    "y_d_bl_fit_sorted = np.exp(p_bl_d(x_d_sorted))-1\n",
    "y_d_pred_fit_sorted = np.exp(p_pred_d(x_d_sorted))-1\n",
    "\n",
    "# Trend lines: dashed for the benchmark, solid for the optimization\n",
    "ax_d.plot(x_d_sorted, y_d_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=3)\n",
    "ax_d.plot(x_d_sorted, y_d_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=3)\n",
    "\n",
    "ax_d.set_yscale('log')\n",
    "ax_d.set_title('d', fontsize=16, loc='left')\n",
    "ax_d.set(xlabel='Diversity', ylabel='Added Volume in USD')\n",
    "\n",
    "#############################\n",
    "# INSET AXIS FOR b\n",
    "#############################\n",
    "inset_ax = axes[0, 1].inset_axes([0.6, 0.65, 0.35, 0.3])\n",
    "\n",
    "# Inset data: ECI against the standard deviation of relative relatedness\n",
    "x_b = filtered_countries['ECI_2022'].values\n",
    "y_b_baseline = filtered_countries['avg_std_connections_baseline'].values\n",
    "y_b_pred = filtered_countries['avg_std_connections_pred'].values\n",
    "\n",
    "# Small markers, benchmark first and ECI optimization on top\n",
    "inset_edge_style = dict(edgecolor='black', linewidth=0.7)\n",
    "inset_ax.scatter(x_b, y_b_baseline, color=color_baseline, marker=marker_baseline, s=2,\n",
    "                 alpha=0.4, label='Benchmark Std', **inset_edge_style)\n",
    "inset_ax.scatter(x_b, y_b_pred, color=color_pred, marker=marker_pred, s=4,\n",
    "                 alpha=0.5, label='ECI Opt Std', **inset_edge_style)\n",
    "\n",
    "inset_ax.set_xlabel('ECI', fontsize=8)\n",
    "inset_ax.set_ylabel('St. Dev of\\nRelative Relatedness', fontsize=8)\n",
    "inset_ax.tick_params(axis='both', labelsize=8)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so each curve is drawn left to right\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted = p_bl(x_b_sorted)\n",
    "y_b_pred_fit_sorted = p_pred(x_b_sorted)\n",
    "\n",
    "# Trend lines: dashed for the benchmark, solid for the optimization\n",
    "inset_ax.plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "inset_ax.plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "#############################\n",
    "# LEGENDS AND LAYOUT\n",
    "#############################\n",
    "# Fit the panels into the lower 96% of the figure height; the legend goes at the top centre\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "# Build one figure-level legend from the labelled artists of panel (a)\n",
    "handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', ncol=3, fontsize=12)\n",
    "\n",
    "plt.show()\n",
    "# Write the figure to figure4.png in the working directory at 300 dpi\n",
    "fig.savefig(\"figure4.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fbf0fad",
   "metadata": {},
   "source": [
    "## FIGURE S16: Properties of ECI Optimization in United States MSA Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a38343b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pivot the long table into an MSA x activity matrix of predicted probabilities\n",
    "predicted_msa_matrix = data_optimization_msa.pivot(index='msa_all',\n",
    "                                        columns='products_all',\n",
    "                                        values='predicted_prob')\n",
    "\n",
    "# Reindex the pivoted DataFrame to match the original order of msa and products\n",
    "predicted_msa_matrix = predicted_msa_matrix.reindex(index=msa_probit, columns=activities_probit)\n",
    "\n",
    "# Binary presence matrix: 1.0 where the predicted probability exceeds the threshold, else 0.0\n",
    "M_msa = np.array(predicted_msa_matrix > optimal_threshold, dtype=float)\n",
    "\n",
    "## similarity matrix\n",
    "\n",
    "# Ubiquity and diversity\n",
    "Kp0 = M_msa.sum(axis=0)\n",
    "Kc0 = M_msa.sum(axis=1)\n",
    "\n",
    "# Calculate proximity of products (PHIpp):\n",
    "# PHIpp[i, j] = (number of MSAs with both i and j) / max(Kp0[i], Kp0[j]).\n",
    "# M_msa.T @ M_msa gives every pairwise co-occurrence count at once and np.maximum.outer the\n",
    "# matching denominators, which replaces the element-by-element double loop.\n",
    "PHIpp_msa = (M_msa.T @ M_msa) / np.maximum.outer(Kp0, Kp0)\n",
    "\n",
    "# Independent copy of the input table (copy() must be called; the bare attribute is only the method)\n",
    "probit_msa_data = data_optimization_msa.copy()\n",
    "\n",
    "msaRankings, NaicsRankings, _, _ = eciopt.cplex_rank(M_msa, msa_probit, activities_probit)\n",
    "\n",
    "# Carry the 2022 ECI columns over from msaRankings_2022\n",
    "msaRankings['ECI_2022'] = msaRankings_2022['ECI'].copy()\n",
    "msaRankings['ECI_not_normalized_2022'] = msaRankings_2022['ECI_not_normalized'].copy()\n",
    "\n",
    "# NOTE(review): PCI_2022 is copied from this table's own PCI column - confirm this is intended\n",
    "NaicsRankings['PCI_2022'] = NaicsRankings['PCI'].copy()\n",
    "pci = NaicsRankings['PCI'].values\n",
    "\n",
    "\n",
    "# Non-normalized ECI of each MSA = average PCI over the activities present in its row of M_msa\n",
    "product_vals = M_msa @ pci.reshape(-1, 1)         # column vector: summed PCI of present activities\n",
    "row_sums = M_msa.sum(axis=1, keepdims=True)       # column vector: number of present activities\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Flatten the single-column result before storing it as a DataFrame column\n",
    "msaRankings['ECI_not_normalized'] = normalized_product.flatten()\n",
    "sd_for_msa = normalized_product.std()\n",
    "mean_for_msa = normalized_product.mean()\n",
    "\n",
    "\n",
    "# Working copy of the MSA rankings (no rows are filtered out here); the loop below iterates over it\n",
    "filtered_msa = msaRankings.copy()\n",
    "\n",
    "# Initialize lists to store results\n",
    "# Per-activity outputs and starting vectors (one entry appended per MSA in the loop below)\n",
    "final_RCA_columns = []\n",
    "final_RCA_columns_baseline = []\n",
    "final_RCA_columns_wp = []\n",
    "initial_RCA_starts = []\n",
    "initial_Relatedness_starts = []\n",
    "\n",
    "final_added_exports = []\n",
    "final_added_exports_baseline = []\n",
    "\n",
    "# Per-MSA summary statistics for the ECI optimization\n",
    "added_exports_sum = []\n",
    "added_exports_count = []\n",
    "avg_new_connections = []\n",
    "avg_std_connections = []\n",
    "avg_new_rca = []\n",
    "avg_new_Ycp = []\n",
    "rca_start_count = []\n",
    "new_rca_count = []\n",
    "\n",
    "# NOTE(review): the '_wp' lists (including final_RCA_columns_wp above) are never appended to\n",
    "# in the visible part of this cell - confirm whether they are still needed\n",
    "added_exports_wp_sum = []\n",
    "added_exports_wp_count = []\n",
    "avg_new_connections_wp = []\n",
    "new_rca_wp_count = []\n",
    "\n",
    "# Per-MSA summary statistics for the benchmark strategy\n",
    "added_exports_baseline_sum = []\n",
    "added_exports_baseline_count = []\n",
    "avg_new_connections_baseline = []\n",
    "avg_std_connections_baseline = []\n",
    "avg_new_rca_baseline = []\n",
    "avg_new_Ycp_baseline = []\n",
    "new_rca_baseline_count = []\n",
    "\n",
    "# NOTE(review): also never appended to in the visible part of this cell\n",
    "acc_final = []\n",
    "acc_final_baseline = []\n",
    "\n",
    "\n",
    "# Row sums of the MSA-level proximity matrix\n",
    "product_connections = PHIpp_msa.sum(axis=1)\n",
    "\n",
    "# Number of MSAs, used for the progress percentage printed in the loop\n",
    "total_msa = len(filtered_msa)\n",
    "\n",
    "\n",
    "# Per-MSA pass: rebuild the activity-level starting state, run the ECI optimization and the\n",
    "# relatedness-times-PCI benchmark, then append one value per summary metric to the lists above.\n",
    "for idx, target_msa in enumerate(filtered_msa['Country'], start=1):\n",
    "    # Calculate the percentage completed\n",
    "    percentage_completed = (idx / total_msa) * 100\n",
    "\n",
    "    # Print the target MSA and the percentage completed\n",
    "    print(f\"Processing country: {target_msa} ({percentage_completed:.2f}% completed)\")\n",
    "\n",
    "    # Positions of this MSA's rows in the long-format optimization table\n",
    "    k = np.where(data_optimization_msa['msa_all'] == target_msa)[0]\n",
    "\n",
    "    # Row of this MSA in the matrices: first match of its label in msaRankings\n",
    "    msa_index = np.where(msaRankings['Country'] == target_msa)[0]\n",
    "    X_start = X_msa[msa_index[0], :].T\n",
    "    M_start = M_msa[msa_index[0], :].T\n",
    "\n",
    "    X_c_start = np.sum(X_start)            # sum of the MSA's row of X_msa\n",
    "    X_p_start = np.sum(X_msa, axis=0).T    # column sums of X_msa (one per activity)\n",
    "    W_p = X_p_start / np.sum(X_p_start)    # each activity's share of the grand total\n",
    "\n",
    "    Relatedness_start = data_optimization_msa['Relatedness_start']\n",
    "    Relatedness_start = Relatedness_start[k]\n",
    "    Relatedness_start = Relatedness_start.to_numpy()\n",
    "\n",
    "    Relative_relatedness_start = data_optimization_msa['Relative_relatedness_start']\n",
    "    Relative_relatedness_start = Relative_relatedness_start[k]\n",
    "    Relative_relatedness_start = Relative_relatedness_start.to_numpy()\n",
    "\n",
    "    predicted_prob = data_optimization_msa['predicted_prob']\n",
    "    predicted_prob = predicted_prob[k]\n",
    "    predicted_prob = predicted_prob.to_numpy()\n",
    "\n",
    "    # Starting RCA: the MSA's activity share divided by the activity's overall share\n",
    "    RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "    # NaicsRankings is overwritten in place with this MSA's vectors on every iteration\n",
    "    NaicsRankings['X_start'] = X_start\n",
    "    NaicsRankings['Relatedness_start'] = Relatedness_start\n",
    "    NaicsRankings['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    NaicsRankings['predicted_prob'] = predicted_prob\n",
    "    NaicsRankings['M_start'] = M_start\n",
    "    NaicsRankings['X_p_start'] = X_p_start\n",
    "    NaicsRankings['W_p'] = W_p\n",
    "    NaicsRankings['RCA_start'] = RCA_start\n",
    "\n",
    "    # Relatedness of the activities with RCA_start <= 1 only (the others are set to NaN) ...\n",
    "    NaicsRankings['relative_relatedness'] = np.where(NaicsRankings['RCA_start'] > 1, np.nan, NaicsRankings['Relatedness_start'])\n",
    "\n",
    "    # ... standardized to z-scores, ignoring the NaN entries\n",
    "    NaicsRankings['relative_relatedness'] = zscore(NaicsRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "    indices_to_exclude = []\n",
    "\n",
    "    ECI_initial = msaRankings.loc[msaRankings['Country'] == target_msa, 'ECI_not_normalized'].values[0]\n",
    "    ECI_initial_norm = msaRankings.loc[msaRankings['Country'] == target_msa, 'ECI'].values[0]\n",
    "\n",
    "    # Find the bin [ECI_Lower, ECI_Upper) where the normalized ECI falls\n",
    "    matching_row = msa_eci_pci_whisker_df[\n",
    "        (msa_eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "        (msa_eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "    ]\n",
    "\n",
    "    # Extract the threshold (the bin's upper PCI whisker)\n",
    "    if not matching_row.empty:\n",
    "        tresh = matching_row['PCI_UpperWhisker'].values[0]\n",
    "    else:\n",
    "        tresh = np.nan  # fallback if no bin matches\n",
    "\n",
    "    # Whiskers above 1 are replaced by 10 (a NaN threshold fails the comparison and stays NaN).\n",
    "    # NOTE(review): presumably 10 effectively switches the PCI cap off - confirm in eci_optimization.\n",
    "    if tresh > 1:\n",
    "        tresh = 10\n",
    "\n",
    "    # Target: raise the non-normalized ECI by 0.1\n",
    "    max_ECI_target = ECI_initial + 0.1\n",
    "\n",
    "    df_eci = eciopt.eci_optimization(target_msa, max_ECI_target, msaRankings, NaicsRankings, indices_to_exclude, beta_msa_entry, beta_msa_exit, PHIpp_msa,tresh)\n",
    "\n",
    "    # Benchmark score per activity: min-max scaled relatedness times min-max scaled PCI\n",
    "    Relatedness_start_baseline = (Relatedness_start - Relatedness_start.min()) / (Relatedness_start.max() - Relatedness_start.min())\n",
    "    pci = NaicsRankings['PCI'].values\n",
    "    pci_baseline = (pci - pci.min()) / (pci.max() - pci.min())\n",
    "    baseline_vals = Relatedness_start_baseline * pci_baseline\n",
    "\n",
    "    df_eci_baseline = eciopt.find_products_criteria(target_msa, max_ECI_target, msaRankings, NaicsRankings, indices_to_exclude, beta_msa_entry, beta_msa_exit, baseline_vals)\n",
    "\n",
    "    # Split the activities at RCA_start = 1: 'entry' coefficients below, 'exit' coefficients at or above\n",
    "    RCA_start_entry = RCA_start[RCA_start < 1]\n",
    "    Relatedness_start_entry = Relatedness_start[RCA_start < 1]\n",
    "    Relative_relatedness_start_entry = Relative_relatedness_start[RCA_start < 1]\n",
    "\n",
    "    RCA_start_exit = RCA_start[RCA_start >= 1]\n",
    "    Relatedness_start_exit = Relatedness_start[RCA_start >= 1]\n",
    "    Relative_relatedness_start_exit = Relative_relatedness_start[RCA_start >= 1]\n",
    "\n",
    "    # Algebraically, Y solves b0 + b1*log(1 + RCA + Y) + b2*log(1 + RCA) + b3*rel + b4*rel_rel = log(2).\n",
    "    # NOTE(review): presumably the RCA increase required by the fitted entry/exit model - confirm.\n",
    "    Ycp_entry = np.exp((np.log(2)-(beta_msa_entry[0] + beta_msa_entry[2] * np.log(1+RCA_start_entry) + beta_msa_entry[3] * Relatedness_start_entry + beta_msa_entry[4] * Relative_relatedness_start_entry))/beta_msa_entry[1]) - RCA_start_entry - 1\n",
    "    Ycp_exit = np.exp((np.log(2)-(beta_msa_exit[0] + beta_msa_exit[2] * np.log(1+RCA_start_exit) + beta_msa_exit[3] * Relatedness_start_exit + beta_msa_exit[4] * Relative_relatedness_start_exit))/beta_msa_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "    # Put the two halves back into one vector in the original activity order\n",
    "    Ycp = np.full(RCA_start.shape, np.nan)\n",
    "    Ycp[RCA_start < 1] = Ycp_entry\n",
    "    Ycp[RCA_start >= 1] = Ycp_exit\n",
    "\n",
    "    # Keep the per-activity outputs of both strategies and the starting vectors\n",
    "    final_added_exports.append(df_eci['Added_vol'])\n",
    "    final_added_exports_baseline.append(df_eci_baseline['Added_vol'])\n",
    "    final_RCA_columns.append(df_eci['RCA_final'])\n",
    "    final_RCA_columns_baseline.append(df_eci_baseline['RCA_final'])\n",
    "    initial_RCA_starts.append(RCA_start)\n",
    "    initial_Relatedness_starts.append(Relatedness_start)\n",
    "\n",
    "    # Calculate the number of RCA_start > 1\n",
    "    num_rca_start_gt_1 = np.sum(RCA_start > 1)\n",
    "    rca_start_count.append(num_rca_start_gt_1)\n",
    "\n",
    "    # ---- ECI optimization: activities with final RCA > 1 that started at RCA <= 1 ----\n",
    "    final_RCA = df_eci['RCA_final'].to_numpy()\n",
    "    new_rca = np.sum((final_RCA > 1) & (RCA_start <= 1))\n",
    "    new_rca_count.append(new_rca)\n",
    "    added_exports_sum.append(np.sum(df_eci['Added_vol']))\n",
    "    added_exports_count.append(np.sum(df_eci['Added_vol']>0))\n",
    "    mask = (final_RCA > 1) & (RCA_start <= 1)\n",
    "\n",
    "    # Mean z-scored relatedness of the new activities (NaN/inf when there are none)\n",
    "    mean_relative_relatedness = np.sum(mask * NaicsRankings['relative_relatedness']) / new_rca\n",
    "\n",
    "    # Standard deviation over the same activities\n",
    "    std_relative_relatedness = np.std(NaicsRankings['relative_relatedness'][mask])\n",
    "\n",
    "    avg_new_connections.append(mean_relative_relatedness)\n",
    "    avg_std_connections.append(std_relative_relatedness)\n",
    "\n",
    "    # NOTE(review): the numerator sums over activities with Added_vol > 0 while the denominator\n",
    "    # counts activities that newly reach RCA > 1 - confirm that the two sets coincide.\n",
    "    avg_new_rca.append(np.sum((df_eci['Added_vol']> 0) * NaicsRankings['RCA_start'])/new_rca)\n",
    "    avg_new_Ycp.append(np.sum((df_eci['Added_vol']> 0) * Ycp))\n",
    "\n",
    "    # ---- Benchmark: same statistics for the benchmark's own selection ----\n",
    "    final_RCA_baseline = df_eci_baseline['RCA_final'].to_numpy()\n",
    "    new_rca_baseline = np.sum((final_RCA_baseline > 1) & (RCA_start <= 1))\n",
    "    new_rca_baseline_count.append(new_rca_baseline)\n",
    "    added_exports_baseline_sum.append(np.sum(df_eci_baseline['Added_vol']))\n",
    "    added_exports_baseline_count.append(np.sum(df_eci_baseline['Added_vol']>0))\n",
    "    mask = (final_RCA_baseline > 1) & (RCA_start <= 1)\n",
    "\n",
    "    # The mean is taken over the benchmark's new activities, so the denominator must be\n",
    "    # new_rca_baseline (the size of this mask), not the optimization's new_rca.\n",
    "    mean_relative_relatedness = np.sum(mask * NaicsRankings['relative_relatedness']) / new_rca_baseline\n",
    "\n",
    "    # Standard deviation over the same activities\n",
    "    std_relative_relatedness = np.std(NaicsRankings['relative_relatedness'][mask])\n",
    "\n",
    "    avg_new_connections_baseline.append(mean_relative_relatedness)\n",
    "    avg_std_connections_baseline.append(std_relative_relatedness)\n",
    "\n",
    "    # NOTE(review): same numerator/denominator pairing as for the optimization above - confirm.\n",
    "    avg_new_rca_baseline.append(np.sum((df_eci_baseline['Added_vol']> 0) * NaicsRankings['RCA_start'])/new_rca_baseline)\n",
    "    avg_new_Ycp_baseline.append(np.sum(((df_eci_baseline['Added_vol']> 0)) * Ycp))\n",
    "\n",
    "\n",
    "# Start again from a fresh copy of the MSA rankings and attach one column per collected metric\n",
    "filtered_msa = msaRankings.copy()\n",
    "\n",
    "# Column name -> per-MSA list filled by the loop above (dict order = column order)\n",
    "msa_result_columns = {\n",
    "    # ECI optimization\n",
    "    'Mc_pred': rca_start_count,\n",
    "    'delta_Mc_pred': new_rca_count,\n",
    "    'delta_opt_pred': added_exports_count,\n",
    "    'added_exp_pred': added_exports_sum,\n",
    "    'avg_new_connections_pred': avg_new_connections,\n",
    "    'avg_std_connections_pred': avg_std_connections,\n",
    "    'avg_new_rca_pred': avg_new_rca,\n",
    "    'avg_new_Ycp_pred': avg_new_Ycp,\n",
    "    # Benchmark\n",
    "    'delta_Mc_baseline': new_rca_baseline_count,\n",
    "    'delta_opt_baseline': added_exports_baseline_count,\n",
    "    'added_exp_baseline': added_exports_baseline_sum,\n",
    "    'avg_new_connections_baseline': avg_new_connections_baseline,\n",
    "    'avg_std_connections_baseline': avg_std_connections_baseline,\n",
    "    'avg_new_rca_baseline': avg_new_rca_baseline,\n",
    "    'avg_new_Ycp_baseline': avg_new_Ycp_baseline,\n",
    "}\n",
    "for column_name, column_values in msa_result_columns.items():\n",
    "    # Store a copy so the DataFrame does not share the list with the accumulator\n",
    "    filtered_msa[column_name] = column_values.copy()\n",
    "\n",
    "## PLOT FIGURE \n",
    "\n",
    "#############################\n",
    "# 1) HELPER FUNCTIONS\n",
    "#############################\n",
    "def linear_fit(x, y):\n",
    "    \"\"\"\n",
    "    Fit a LinearRegression model of y on the single predictor x.\n",
    "\n",
    "    Returns (model, y_pred): the fitted model and its predictions at the\n",
    "    input x values, in their original (unsorted) order.\n",
    "    \"\"\"\n",
    "    # The estimator is given a 2-D column, so reshape x to (n_samples, 1)\n",
    "    design = np.array(x).reshape(-1, 1)\n",
    "    model = LinearRegression()\n",
    "    model.fit(design, y)\n",
    "    return model, model.predict(design)\n",
    "\n",
    "def poly_fit(x, y, degree=2):\n",
    "    \"\"\"\n",
    "    Least-squares polynomial fit of y on x with the given degree.\n",
    "\n",
    "    Returns (p, y_pred): p is the fitted np.poly1d, callable on any x-array,\n",
    "    and y_pred is p evaluated at the input x in its original (unsorted) order.\n",
    "    \"\"\"\n",
    "    p = np.poly1d(np.polyfit(x, y, deg=degree))\n",
    "    return p, p(x)\n",
    "\n",
    "#############################\n",
    "# 2) SCATTER + FIT LINES\n",
    "#############################\n",
    "\n",
    "# Shared palette and marker shapes for the two series\n",
    "color_pred = [0, 0.4470, 0.7410]      # blue\n",
    "color_baseline = [0.8500, 0.3250, 0.0980]   # orange\n",
    "marker_baseline = 's'\n",
    "marker_pred = 'o'\n",
    "\n",
    "# 2 x 2 grid holding panels a-d\n",
    "fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))\n",
    "\n",
    "# Common scatter styling; the benchmark markers are slightly smaller and more transparent\n",
    "scatter_kwargs_baseline = {'marker': marker_baseline, 'color': color_baseline,\n",
    "                           's': 30, 'edgecolor': 'black', 'linewidth': 0.7, 'alpha': 0.4}\n",
    "scatter_kwargs_pred = {'marker': marker_pred, 'color': color_pred,\n",
    "                       's': 40, 'edgecolor': 'black', 'linewidth': 0.7, 'alpha': 0.5}\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (a): ECI_2022 vs. the avg_new_rca_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_a = filtered_msa['ECI_2022'].values\n",
    "y_a_baseline, y_a_pred = (filtered_msa[name].values\n",
    "                          for name in ('avg_new_rca_baseline', 'avg_new_rca_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[0, 0].scatter(x_a, y_a_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[0, 0].scatter(x_a, y_a_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_a_bl_fit_unsorted = poly_fit(x_a, y_a_baseline, degree=2)\n",
    "p_pred, y_a_pred_fit_unsorted = poly_fit(x_a, y_a_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so the curves are drawn left to right\n",
    "x_a_sorted = np.sort(x_a)\n",
    "y_a_bl_fit_sorted, y_a_pred_fit_sorted = p_bl(x_a_sorted), p_pred(x_a_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[0, 0].plot(x_a_sorted, y_a_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[0, 0].plot(x_a_sorted, y_a_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[0, 0].set_title('a', fontsize=16, loc='left')\n",
    "axes[0, 0].set(xlabel='ECI', ylabel='Average Current RCA\\nof New Products')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (b): ECI_2022 vs. the avg_new_connections_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_b = filtered_msa['ECI_2022'].values\n",
    "y_b_baseline, y_b_pred = (filtered_msa[name].values\n",
    "                          for name in ('avg_new_connections_baseline', 'avg_new_connections_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[0, 1].scatter(x_b, y_b_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[0, 1].scatter(x_b, y_b_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so the curves are drawn left to right\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted, y_b_pred_fit_sorted = p_bl(x_b_sorted), p_pred(x_b_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[0, 1].plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[0, 1].plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[0, 1].set_title('b', fontsize=16, loc='left')\n",
    "axes[0, 1].set(xlabel='ECI', ylabel='Average Relative\\nRelatedness of New Activities')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (c): Mc_pred vs. the delta_Mc_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_c = filtered_msa['Mc_pred'].values\n",
    "y_c_baseline, y_c_pred = (filtered_msa[name].values\n",
    "                          for name in ('delta_Mc_baseline', 'delta_Mc_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[1, 0].scatter(x_c, y_c_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[1, 0].scatter(x_c, y_c_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Straight-line trend per series\n",
    "model_bl_c, y_c_bl_fit_unsorted = linear_fit(x_c, y_c_baseline)\n",
    "model_pred_c, y_c_pred_fit_unsorted = linear_fit(x_c, y_c_pred)\n",
    "\n",
    "# Predict on ascending x (as a column) so the lines are drawn left to right\n",
    "x_c_sorted = np.sort(x_c)\n",
    "x_c_sorted_column = x_c_sorted.reshape(-1, 1)\n",
    "y_c_bl_fit_sorted = model_bl_c.predict(x_c_sorted_column)\n",
    "y_c_pred_fit_sorted = model_pred_c.predict(x_c_sorted_column)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[1, 0].plot(x_c_sorted, y_c_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[1, 0].plot(x_c_sorted, y_c_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[1, 0].set_title('c', fontsize=16, loc='left')\n",
    "axes[1, 0].set(xlabel='Diversity', ylabel='Number of New Activities')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (d): Mc_pred vs. the added_exp_* columns, log-scaled y axis\n",
    "# ---------------------------------------------------------------\n",
    "x_d = filtered_msa['Mc_pred'].values\n",
    "# The trend is fitted on log(1 + value); points and lines are mapped back with exp(.) - 1\n",
    "y_d_baseline = np.log(1+filtered_msa['added_exp_baseline'].values)\n",
    "y_d_pred = np.log(1+filtered_msa['added_exp_pred'].values)\n",
    "\n",
    "# Raw points, back on the original scale\n",
    "axes[1, 1].scatter(x_d, np.exp(y_d_baseline)-1, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[1, 1].scatter(x_d, np.exp(y_d_pred)-1, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Degree-1 polynomial fitted in log space\n",
    "p_bl_d, y_d_bl_fit_unsorted = poly_fit(x_d, y_d_baseline, degree=1)\n",
    "p_pred_d, y_d_pred_fit_unsorted = poly_fit(x_d, y_d_pred, degree=1)\n",
    "\n",
    "# Evaluate on ascending x and undo the log transform\n",
    "x_d_sorted = np.sort(x_d)\n",
    "y_d_bl_fit_sorted = np.exp(p_bl_d(x_d_sorted))-1\n",
    "y_d_pred_fit_sorted = np.exp(p_pred_d(x_d_sorted))-1\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization (drawn thicker than in the other panels)\n",
    "axes[1, 1].plot(x_d_sorted, y_d_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=3)\n",
    "axes[1, 1].plot(x_d_sorted, y_d_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=3)\n",
    "\n",
    "axes[1, 1].set_yscale('log')\n",
    "axes[1, 1].set_title('d', fontsize=16, loc='left')\n",
    "axes[1, 1].set(xlabel='Diversity', ylabel='Added Volume in USD')\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# Inset in panel (b): ECI_2022 vs. the avg_std_connections_* columns\n",
    "# ------------------------------------------------------------------\n",
    "inset_ax = axes[0, 1].inset_axes([0.6, 0.65, 0.35, 0.3])\n",
    "# Same styling as the main scatters, only with much smaller markers\n",
    "inset_ax.scatter(x_b, filtered_msa['avg_std_connections_baseline'],\n",
    "                 label='Benchmark Std', **dict(scatter_kwargs_baseline, s=2))\n",
    "inset_ax.scatter(x_b, filtered_msa['avg_std_connections_pred'],\n",
    "                 label='ECI Opt Std', **dict(scatter_kwargs_pred, s=4))\n",
    "\n",
    "inset_ax.set_xlabel('ECI', fontsize=8)\n",
    "inset_ax.set_ylabel('St. Dev of\\nRelative Relatedness', fontsize=8)\n",
    "inset_ax.tick_params(axis='both', labelsize=8)\n",
    "\n",
    "# The panel (b) variable names are re-used for the standard-deviation series\n",
    "x_b = filtered_msa['ECI_2022'].values\n",
    "y_b_baseline, y_b_pred = (filtered_msa[name].values\n",
    "                          for name in ('avg_std_connections_baseline', 'avg_std_connections_pred'))\n",
    "\n",
    "# Quadratic trend per series, evaluated on ascending x\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted, y_b_pred_fit_sorted = p_bl(x_b_sorted), p_pred(x_b_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "inset_ax.plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "inset_ax.plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Shared legend, layout and export\n",
    "# ---------------------------------------------------------------\n",
    "# Keep the top 4% of the canvas free for the figure-level legend\n",
    "fig.tight_layout(rect=(0, 0, 1, 0.96))\n",
    "# Panel (a) holds both labelled scatter series, so its handles serve the whole figure\n",
    "handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "fig.legend(handles=handles, labels=labels, loc='upper center', ncol=3, fontsize=12)\n",
    "\n",
    "plt.show()\n",
    "fig.savefig(\"figure_s16.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3a324bc-c531-4718-94fb-26f900abeb3c",
   "metadata": {},
   "source": [
    "## FIGURE S16: Properties of ECI Optimization in Patents Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8144c227-f8b8-41d2-a464-c5bb04da6cdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wide matrix of predicted probabilities (rows: pct_all, columns: products_all),\n",
    "# reordered to the row/column order given by pct_probit and patents_probit.\n",
    "predicted_pct_matrix = (\n",
    "    data_optimization_pct\n",
    "    .pivot(index='pct_all', columns='products_all', values='predicted_prob')\n",
    "    .reindex(index=pct_probit, columns=patents_probit)\n",
    ")\n",
    "\n",
    "# Binary matrix stored as floats: 1.0 where the predicted probability exceeds optimal_threshold\n",
    "M_pct = np.array(predicted_pct_matrix > optimal_threshold, dtype=float)\n",
    "\n",
    "## similarity matrix\n",
    "\n",
    "# Ubiquity (column sums of M_pct) and diversity (row sums of M_pct)\n",
    "Kp0 = M_pct.sum(axis=0)\n",
    "Kc0 = M_pct.sum(axis=1)\n",
    "\n",
    "# Proximity of products: PHIpp[i, j] = (M_pct[:, i] . M_pct[:, j]) / max(Kp0[i], Kp0[j]).\n",
    "# M_pct.T @ M_pct yields every pairwise dot product at once and np.maximum.outer every\n",
    "# pairwise maximum, so the element-by-element Python double loop is not needed.\n",
    "PHIpp_pct = (M_pct.T @ M_pct) / np.maximum.outer(Kp0, Kp0)\n",
    "\n",
    "# Independent copy of the long table; .copy must be called, otherwise the name is bound\n",
    "# to the method object instead of to a DataFrame.\n",
    "probit_pct_data = data_optimization_pct.copy()\n",
    "\n",
    "# Rankings computed from the binary matrix; the last two return values are not used\n",
    "pctRankings, PCTRankings, _, _ = eciopt.cplex_rank(M_pct, pct_probit, patents_probit)\n",
    "\n",
    "# Carry over the ECI columns of pctRankings_2022\n",
    "pctRankings['ECI_2022'] = pctRankings_2022['ECI'].copy()\n",
    "pctRankings['ECI_not_normalized_2022'] = pctRankings_2022['ECI_not_normalized'].copy()\n",
    "\n",
    "# Keep a snapshot of the PCI column and expose it as a plain array\n",
    "PCTRankings['PCI_2022'] = PCTRankings['PCI'].copy()\n",
    "pci = PCTRankings['PCI'].values\n",
    "\n",
    "\n",
    "# Mean PCI over the activities present in each row of M_pct (the un-normalized ECI).\n",
    "# Numerator: per-row sum of PCI over present activities, kept as a column vector.\n",
    "product_vals = M_pct @ pci.reshape(-1, 1)\n",
    "# Denominator: per-row sum of M_pct, also kept as a column vector.\n",
    "row_sums = M_pct.sum(axis=1, keepdims=True)\n",
    "normalized_product = product_vals / row_sums\n",
    "\n",
    "# Store the result as a flat column and record its standard deviation and mean\n",
    "pctRankings['ECI_not_normalized'] = normalized_product.flatten()\n",
    "sd_for_pct, mean_for_pct = np.std(normalized_product), np.mean(normalized_product)\n",
    "\n",
    "\n",
    "# Working copy of the rankings; its 'Country' column drives the loop below\n",
    "filtered_pct = pctRankings.copy()\n",
    "\n",
    "# Raw per-target outputs\n",
    "final_RCA_columns, final_RCA_columns_baseline, final_RCA_columns_wp = [], [], []\n",
    "initial_RCA_starts, initial_Relatedness_starts = [], []\n",
    "final_added_exports, final_added_exports_baseline = [], []\n",
    "\n",
    "# Summary statistics derived from df_eci\n",
    "added_exports_sum, added_exports_count = [], []\n",
    "avg_new_connections, avg_std_connections = [], []\n",
    "avg_new_rca, avg_new_Ycp = [], []\n",
    "rca_start_count, new_rca_count = [], []\n",
    "\n",
    "# '_wp' accumulators (created here; the loop in this cell does not append to them)\n",
    "added_exports_wp_sum, added_exports_wp_count = [], []\n",
    "avg_new_connections_wp, new_rca_wp_count = [], []\n",
    "\n",
    "# Summary statistics derived from df_eci_baseline\n",
    "added_exports_baseline_sum, added_exports_baseline_count = [], []\n",
    "avg_new_connections_baseline, avg_std_connections_baseline = [], []\n",
    "avg_new_rca_baseline, avg_new_Ycp_baseline = [], []\n",
    "new_rca_baseline_count = []\n",
    "\n",
    "# Accuracy accumulators (created here; the loop in this cell does not append to them)\n",
    "acc_final, acc_final_baseline = [], []\n",
    "\n",
    "# Row sums of the proximity matrix\n",
    "product_connections = PHIpp_pct.sum(axis=1)\n",
    "\n",
    "# Number of targets, used for the progress percentage\n",
    "total_pct = len(filtered_pct)\n",
    "\n",
    "\n",
    "# Column totals of X_pct and their shares are the same for every target, so they are\n",
    "# computed once here instead of on every pass through the loop.\n",
    "X_p_start = np.sum(X_pct, axis=0).T\n",
    "W_p = X_p_start / np.sum(X_p_start)\n",
    "\n",
    "for idx, target_pct in enumerate(filtered_pct['Country'], start=1):\n",
    "    # Progress report\n",
    "    percentage_completed = (idx / total_pct) * 100\n",
    "    print(f\"Processing country: {target_pct} ({percentage_completed:.2f}% completed)\")\n",
    "\n",
    "    # Row positions of this target in the long table data_optimization_pct\n",
    "    k = np.where(data_optimization_pct['pct_all'] == target_pct)[0]\n",
    "\n",
    "    # Row of this target in the X_pct / M_pct matrices\n",
    "    pct_index = np.where(pctRankings['Country'] == target_pct)[0]\n",
    "    X_start = X_pct[pct_index[0], :].T\n",
    "    M_start = M_pct[pct_index[0], :].T\n",
    "\n",
    "    X_c_start = np.sum(X_start)\n",
    "\n",
    "    # k holds positions (it comes from np.where), so select with .iloc; plain [] performs a\n",
    "    # label lookup and only agrees with it when the index is the default 0..n-1 range.\n",
    "    Relatedness_start = data_optimization_pct['Relatedness_start'].iloc[k].to_numpy()\n",
    "    Relative_relatedness_start = data_optimization_pct['Relative_relatedness_start'].iloc[k].to_numpy()\n",
    "    predicted_prob = data_optimization_pct['predicted_prob'].iloc[k].to_numpy()\n",
    "\n",
    "    # Revealed comparative advantage: share within the target divided by the overall share\n",
    "    RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "    # Expose the starting values of this target as columns of PCTRankings\n",
    "    PCTRankings['X_start'] = X_start\n",
    "    PCTRankings['Relatedness_start'] = Relatedness_start\n",
    "    PCTRankings['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    PCTRankings['predicted_prob'] = predicted_prob\n",
    "    PCTRankings['M_start'] = M_start\n",
    "    PCTRankings['X_p_start'] = X_p_start\n",
    "    PCTRankings['W_p'] = W_p\n",
    "    PCTRankings['RCA_start'] = RCA_start\n",
    "\n",
    "    # Relatedness_start where RCA_start is not above 1 (NaN elsewhere), then z-scored\n",
    "    # with the NaN entries omitted from the mean and standard deviation.\n",
    "    PCTRankings['relative_relatedness'] = np.where(PCTRankings['RCA_start'] > 1, np.nan, PCTRankings['Relatedness_start'])\n",
    "    PCTRankings['relative_relatedness'] = zscore(PCTRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "    \n",
    "    indices_to_exclude = []\n",
    "\n",
    "    # Current ECI of the target: un-normalized and normalized\n",
    "    ECI_initial = pctRankings.loc[pctRankings['Country'] == target_pct, 'ECI_not_normalized'].values[0]\n",
    "    ECI_initial_norm = pctRankings.loc[pctRankings['Country'] == target_pct, 'ECI'].values[0]\n",
    "\n",
    "    # Rows of the whisker table whose bin [ECI_Lower, ECI_Upper) contains the normalized ECI\n",
    "    matching_row = pct_eci_pci_whisker_df[\n",
    "        (pct_eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "        (pct_eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "    ]\n",
    "\n",
    "    # Threshold = PCI_UpperWhisker of the first matching bin, NaN when no bin matches\n",
    "    tresh = matching_row['PCI_UpperWhisker'].values[0] if not matching_row.empty else np.nan\n",
    "\n",
    "    # Whisker values above 1 are replaced by 10 (a NaN threshold is left as it is)\n",
    "    if tresh > 1:\n",
    "        tresh = 10\n",
    "\n",
    "    # Target level handed to both selection routines: current un-normalized ECI plus 0.1\n",
    "    max_ECI_target = ECI_initial + 0.1\n",
    "\n",
    "    df_eci = eciopt.eci_optimization(target_pct, max_ECI_target, pctRankings, PCTRankings, indices_to_exclude, beta_pct_entry, beta_pct_exit, PHIpp_pct,tresh)\n",
    "\n",
    "    # Benchmark score per activity: min-max scaled relatedness times min-max scaled PCI\n",
    "    rel_lo, rel_hi = Relatedness_start.min(), Relatedness_start.max()\n",
    "    Relatedness_start_baseline = (Relatedness_start - rel_lo) / (rel_hi - rel_lo)\n",
    "    pci = PCTRankings['PCI'].values\n",
    "    pci_lo, pci_hi = pci.min(), pci.max()\n",
    "    pci_baseline = (pci - pci_lo) / (pci_hi - pci_lo)\n",
    "    baseline_vals = Relatedness_start_baseline * pci_baseline\n",
    "\n",
    "    df_eci_baseline = eciopt.find_products_criteria(target_pct, max_ECI_target, pctRankings, PCTRankings, indices_to_exclude, beta_pct_entry, beta_pct_exit, baseline_vals)\n",
    "\n",
    "    # Split the activities at RCA_start = 1: below 1 uses the entry coefficients,\n",
    "    # at or above 1 uses the exit coefficients.\n",
    "    is_entry = RCA_start < 1\n",
    "    is_exit = RCA_start >= 1\n",
    "\n",
    "    RCA_start_entry = RCA_start[is_entry]\n",
    "    Relatedness_start_entry = Relatedness_start[is_entry]\n",
    "    Relative_relatedness_start_entry = Relative_relatedness_start[is_entry]\n",
    "\n",
    "    RCA_start_exit = RCA_start[is_exit]\n",
    "    Relatedness_start_exit = Relatedness_start[is_exit]\n",
    "    Relative_relatedness_start_exit = Relative_relatedness_start[is_exit]\n",
    "\n",
    "    # Part of each linear index that does not involve the coefficient with index 1\n",
    "    entry_index = beta_pct_entry[0] + beta_pct_entry[2] * np.log(1+RCA_start_entry) + beta_pct_entry[3] * Relatedness_start_entry + beta_pct_entry[4] * Relative_relatedness_start_entry\n",
    "    exit_index = beta_pct_exit[0] + beta_pct_exit[2] * np.log(1+RCA_start_exit) + beta_pct_exit[3] * Relatedness_start_exit + beta_pct_exit[4] * Relative_relatedness_start_exit\n",
    "\n",
    "    Ycp_entry = np.exp((np.log(2)-entry_index)/beta_pct_entry[1]) - RCA_start_entry - 1\n",
    "    Ycp_exit = np.exp((np.log(2)-exit_index)/beta_pct_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "    # Merge both groups back into one vector aligned with RCA_start\n",
    "    Ycp = np.full(RCA_start.shape, np.nan)\n",
    "    Ycp[is_entry] = Ycp_entry\n",
    "    Ycp[is_exit] = Ycp_exit\n",
    "\n",
    "    # Keep the raw per-activity outputs and starting values of this target\n",
    "    final_added_exports.append(df_eci['Added_vol'])\n",
    "    final_added_exports_baseline.append(df_eci_baseline['Added_vol'])\n",
    "    final_RCA_columns.append(df_eci['RCA_final'])\n",
    "    final_RCA_columns_baseline.append(df_eci_baseline['RCA_final'])\n",
    "    initial_RCA_starts.append(RCA_start)\n",
    "    initial_Relatedness_starts.append(Relatedness_start)\n",
    "\n",
    "    # Number of activities with RCA_start > 1\n",
    "    num_rca_start_gt_1 = np.sum(RCA_start > 1)\n",
    "    rca_start_count.append(num_rca_start_gt_1)\n",
    "\n",
    "    # Activities that end above RCA = 1 after starting at or below 1\n",
    "    final_RCA = df_eci['RCA_final'].to_numpy()\n",
    "    mask = (final_RCA > 1) & (RCA_start <= 1)\n",
    "    new_rca = np.sum(mask)\n",
    "    new_rca_count.append(new_rca)\n",
    "\n",
    "    # Total added volume and number of activities with a positive added volume\n",
    "    added_exports_sum.append(np.sum(df_eci['Added_vol']))\n",
    "    added_exports_count.append(np.sum(df_eci['Added_vol']>0))\n",
    "\n",
    "    # Mean and standard deviation of relative relatedness over the new activities\n",
    "    mean_relative_relatedness = np.sum(mask * PCTRankings['relative_relatedness']) / new_rca\n",
    "    std_relative_relatedness = np.std(PCTRankings['relative_relatedness'][mask])\n",
    "    avg_new_connections.append(mean_relative_relatedness)\n",
    "    avg_std_connections.append(std_relative_relatedness)\n",
    "\n",
    "    # RCA_start summed over activities with positive added volume, divided by new_rca;\n",
    "    # Ycp summed over the same activities.\n",
    "    avg_new_rca.append(np.sum((df_eci['Added_vol']> 0) * PCTRankings['RCA_start'])/new_rca)\n",
    "    avg_new_Ycp.append(np.sum((df_eci['Added_vol']> 0) * Ycp))\n",
    "\n",
    "\n",
    "\n",
    "    # Benchmark statistics for the current target, taken from df_eci_baseline.\n",
    "    final_RCA_baseline = df_eci_baseline['RCA_final'].to_numpy()\n",
    "    # Activities that end above RCA = 1 after starting at or below 1.\n",
    "    mask = (final_RCA_baseline > 1) & (RCA_start <= 1)\n",
    "    new_rca_baseline = np.sum(mask)\n",
    "    new_rca_baseline_count.append(new_rca_baseline)\n",
    "    added_exports_baseline_sum.append(np.sum(df_eci_baseline['Added_vol']))\n",
    "    added_exports_baseline_count.append(np.sum(df_eci_baseline['Added_vol']>0))\n",
    "\n",
    "    # Mean relative relatedness over the benchmark's new activities. The divisor must be\n",
    "    # new_rca_baseline: new_rca counts the new activities of the df_eci run, and using it\n",
    "    # here mis-scales the benchmark mean whenever the two counts differ.\n",
    "    mean_relative_relatedness = np.sum(mask * PCTRankings['relative_relatedness']) / new_rca_baseline\n",
    "\n",
    "    # Standard deviation over the same set of activities\n",
    "    std_relative_relatedness = np.std(PCTRankings['relative_relatedness'][mask])\n",
    "\n",
    "    avg_new_connections_baseline.append(mean_relative_relatedness)\n",
    "    avg_std_connections_baseline.append(std_relative_relatedness)\n",
    "\n",
    "    avg_new_rca_baseline.append(np.sum((df_eci_baseline['Added_vol']> 0) * PCTRankings['RCA_start'])/new_rca_baseline)\n",
    "    avg_new_Ycp_baseline.append(np.sum(((df_eci_baseline['Added_vol']> 0)) * Ycp))\n",
    "\n",
    "\n",
    "# Attach the per-target statistics gathered in the loop above to a copy of pctRankings.\n",
    "# Every list is copied, so the accumulators stay independent of the DataFrame columns.\n",
    "filtered_pct = pctRankings.copy().assign(\n",
    "    # Statistics derived from df_eci\n",
    "    Mc_pred=rca_start_count.copy(),\n",
    "    delta_Mc_pred=new_rca_count.copy(),\n",
    "    delta_opt_pred=added_exports_count.copy(),\n",
    "    added_exp_pred=added_exports_sum.copy(),\n",
    "    avg_new_connections_pred=avg_new_connections.copy(),\n",
    "    avg_std_connections_pred=avg_std_connections.copy(),\n",
    "    avg_new_rca_pred=avg_new_rca.copy(),\n",
    "    avg_new_Ycp_pred=avg_new_Ycp.copy(),\n",
    "    # Statistics derived from df_eci_baseline\n",
    "    delta_Mc_baseline=new_rca_baseline_count.copy(),\n",
    "    delta_opt_baseline=added_exports_baseline_count.copy(),\n",
    "    added_exp_baseline=added_exports_baseline_sum.copy(),\n",
    "    avg_new_connections_baseline=avg_new_connections_baseline.copy(),\n",
    "    avg_std_connections_baseline=avg_std_connections_baseline.copy(),\n",
    "    avg_new_rca_baseline=avg_new_rca_baseline.copy(),\n",
    "    avg_new_Ycp_baseline=avg_new_Ycp_baseline.copy(),\n",
    ")\n",
    "\n",
    "## PLOT FIGURE \n",
    "\n",
    "#############################\n",
    "# 1) HELPER FUNCTIONS\n",
    "#############################\n",
    "def linear_fit(x, y):\n",
    "    \"\"\"\n",
    "    Fit a LinearRegression model of y on the single predictor x.\n",
    "\n",
    "    Returns (model, y_pred): the fitted model and its predictions at the\n",
    "    input x values, in their original (unsorted) order.\n",
    "    \"\"\"\n",
    "    # The estimator is given a 2-D column, so reshape x to (n_samples, 1)\n",
    "    design = np.array(x).reshape(-1, 1)\n",
    "    model = LinearRegression()\n",
    "    model.fit(design, y)\n",
    "    return model, model.predict(design)\n",
    "\n",
    "def poly_fit(x, y, degree=2):\n",
    "    \"\"\"\n",
    "    Least-squares polynomial fit of y on x with the given degree.\n",
    "\n",
    "    Returns (p, y_pred): p is the fitted np.poly1d, callable on any x-array,\n",
    "    and y_pred is p evaluated at the input x in its original (unsorted) order.\n",
    "    \"\"\"\n",
    "    p = np.poly1d(np.polyfit(x, y, deg=degree))\n",
    "    return p, p(x)\n",
    "\n",
    "#############################\n",
    "# 2) SCATTER + FIT LINES\n",
    "#############################\n",
    "\n",
    "# Shared palette and marker shapes for the two series\n",
    "color_pred = [0, 0.4470, 0.7410]      # blue\n",
    "color_baseline = [0.8500, 0.3250, 0.0980]   # orange\n",
    "marker_baseline = 's'\n",
    "marker_pred = 'o'\n",
    "\n",
    "# 2 x 2 grid holding panels a-d\n",
    "fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))\n",
    "\n",
    "# Common scatter styling; the benchmark markers are slightly smaller and more transparent\n",
    "scatter_kwargs_baseline = {'marker': marker_baseline, 'color': color_baseline,\n",
    "                           's': 30, 'edgecolor': 'black', 'linewidth': 0.7, 'alpha': 0.4}\n",
    "scatter_kwargs_pred = {'marker': marker_pred, 'color': color_pred,\n",
    "                       's': 40, 'edgecolor': 'black', 'linewidth': 0.7, 'alpha': 0.5}\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (a): ECI_2022 vs. the avg_new_rca_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_a = filtered_pct['ECI_2022'].values\n",
    "y_a_baseline, y_a_pred = (filtered_pct[name].values\n",
    "                          for name in ('avg_new_rca_baseline', 'avg_new_rca_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[0, 0].scatter(x_a, y_a_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[0, 0].scatter(x_a, y_a_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_a_bl_fit_unsorted = poly_fit(x_a, y_a_baseline, degree=2)\n",
    "p_pred, y_a_pred_fit_unsorted = poly_fit(x_a, y_a_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so the curves are drawn left to right\n",
    "x_a_sorted = np.sort(x_a)\n",
    "y_a_bl_fit_sorted, y_a_pred_fit_sorted = p_bl(x_a_sorted), p_pred(x_a_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[0, 0].plot(x_a_sorted, y_a_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[0, 0].plot(x_a_sorted, y_a_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[0, 0].set_title('a', fontsize=16, loc='left')\n",
    "axes[0, 0].set(xlabel='ECI', ylabel='Average Current RCA\\nof New Products')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (b): ECI_2022 vs. the avg_new_connections_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_b = filtered_pct['ECI_2022'].values\n",
    "y_b_baseline, y_b_pred = (filtered_pct[name].values\n",
    "                          for name in ('avg_new_connections_baseline', 'avg_new_connections_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[0, 1].scatter(x_b, y_b_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[0, 1].scatter(x_b, y_b_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Quadratic trend per series\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "\n",
    "# Evaluate the trends on ascending x so the curves are drawn left to right\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted, y_b_pred_fit_sorted = p_bl(x_b_sorted), p_pred(x_b_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[0, 1].plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[0, 1].plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[0, 1].set_title('b', fontsize=16, loc='left')\n",
    "axes[0, 1].set(xlabel='ECI', ylabel='Average Relative\\nRelatedness of New Activities')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (c): Mc_pred vs. the delta_Mc_* columns\n",
    "# ---------------------------------------------------------------\n",
    "x_c = filtered_pct['Mc_pred'].values\n",
    "y_c_baseline, y_c_pred = (filtered_pct[name].values\n",
    "                          for name in ('delta_Mc_baseline', 'delta_Mc_pred'))\n",
    "\n",
    "# Raw points for both series\n",
    "axes[1, 0].scatter(x_c, y_c_baseline, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[1, 0].scatter(x_c, y_c_pred, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Straight-line trend per series\n",
    "model_bl_c, y_c_bl_fit_unsorted = linear_fit(x_c, y_c_baseline)\n",
    "model_pred_c, y_c_pred_fit_unsorted = linear_fit(x_c, y_c_pred)\n",
    "\n",
    "# Predict on ascending x (as a column) so the lines are drawn left to right\n",
    "x_c_sorted = np.sort(x_c)\n",
    "x_c_sorted_column = x_c_sorted.reshape(-1, 1)\n",
    "y_c_bl_fit_sorted = model_bl_c.predict(x_c_sorted_column)\n",
    "y_c_pred_fit_sorted = model_pred_c.predict(x_c_sorted_column)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "axes[1, 0].plot(x_c_sorted, y_c_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "axes[1, 0].plot(x_c_sorted, y_c_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "axes[1, 0].set_title('c', fontsize=16, loc='left')\n",
    "axes[1, 0].set(xlabel='Diversity', ylabel='Number of New Activities')\n",
    "\n",
    "# ---------------------------------------------------------------\n",
    "# Panel (d): Mc_pred vs. the added_exp_* columns, log-scaled y axis\n",
    "# ---------------------------------------------------------------\n",
    "x_d = filtered_pct['Mc_pred'].values\n",
    "# The trend is fitted on log(1 + value); points and lines are mapped back with exp(.) - 1\n",
    "y_d_baseline = np.log(1+filtered_pct['added_exp_baseline'].values)\n",
    "y_d_pred = np.log(1+filtered_pct['added_exp_pred'].values)\n",
    "\n",
    "# Raw points, back on the original scale\n",
    "axes[1, 1].scatter(x_d, np.exp(y_d_baseline)-1, label='Benchmark', **scatter_kwargs_baseline)\n",
    "axes[1, 1].scatter(x_d, np.exp(y_d_pred)-1, label='ECI Optimization', **scatter_kwargs_pred)\n",
    "\n",
    "# Degree-1 polynomial fitted in log space\n",
    "p_bl_d, y_d_bl_fit_unsorted = poly_fit(x_d, y_d_baseline, degree=1)\n",
    "p_pred_d, y_d_pred_fit_unsorted = poly_fit(x_d, y_d_pred, degree=1)\n",
    "\n",
    "# Evaluate on ascending x and undo the log transform\n",
    "x_d_sorted = np.sort(x_d)\n",
    "y_d_bl_fit_sorted = np.exp(p_bl_d(x_d_sorted))-1\n",
    "y_d_pred_fit_sorted = np.exp(p_pred_d(x_d_sorted))-1\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization (drawn thicker than in the other panels)\n",
    "axes[1, 1].plot(x_d_sorted, y_d_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=3)\n",
    "axes[1, 1].plot(x_d_sorted, y_d_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=3)\n",
    "\n",
    "axes[1, 1].set_yscale('log')\n",
    "axes[1, 1].set_title('d', fontsize=16, loc='left')\n",
    "axes[1, 1].set(xlabel='Diversity', ylabel='Added Volume in USD')\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# Inset in panel (b): ECI_2022 vs. the avg_std_connections_* columns\n",
    "# ------------------------------------------------------------------\n",
    "inset_ax = axes[0, 1].inset_axes([0.6, 0.65, 0.35, 0.3])\n",
    "# Same styling as the main scatters, only with much smaller markers\n",
    "inset_ax.scatter(x_b, filtered_pct['avg_std_connections_baseline'],\n",
    "                 label='Benchmark Std', **dict(scatter_kwargs_baseline, s=2))\n",
    "inset_ax.scatter(x_b, filtered_pct['avg_std_connections_pred'],\n",
    "                 label='ECI Opt Std', **dict(scatter_kwargs_pred, s=4))\n",
    "\n",
    "inset_ax.set_xlabel('ECI', fontsize=8)\n",
    "inset_ax.set_ylabel('St. Dev of\\nRelative Relatedness', fontsize=8)\n",
    "inset_ax.tick_params(axis='both', labelsize=8)\n",
    "\n",
    "# The panel (b) variable names are re-used for the standard-deviation series\n",
    "x_b = filtered_pct['ECI_2022'].values\n",
    "y_b_baseline, y_b_pred = (filtered_pct[name].values\n",
    "                          for name in ('avg_std_connections_baseline', 'avg_std_connections_pred'))\n",
    "\n",
    "# Quadratic trend per series, evaluated on ascending x\n",
    "p_bl, y_b_bl_fit_unsorted = poly_fit(x_b, y_b_baseline, degree=2)\n",
    "p_pred, y_b_pred_fit_unsorted = poly_fit(x_b, y_b_pred, degree=2)\n",
    "x_b_sorted = np.sort(x_b)\n",
    "y_b_bl_fit_sorted, y_b_pred_fit_sorted = p_bl(x_b_sorted), p_pred(x_b_sorted)\n",
    "\n",
    "# Dashed = benchmark, solid = ECI optimization\n",
    "inset_ax.plot(x_b_sorted, y_b_bl_fit_sorted, color=color_baseline, linestyle='--', alpha=1, linewidth=2)\n",
    "inset_ax.plot(x_b_sorted, y_b_pred_fit_sorted, color=color_pred, linestyle='-', alpha=1, linewidth=2)\n",
    "\n",
    "#############################\n",
    "# LEGENDS AND LAYOUT\n",
    "#############################\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', ncol=3, fontsize=12)\n",
    "\n",
    "plt.show()\n",
    "fig.savefig(\"figure_s17.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a3d2515",
   "metadata": {},
   "source": [
    "## CALCULATE GROWTH MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2edb08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from scipy.stats import zscore\n",
    "\n",
    "import statsmodels.formula.api as smf\n",
    "from stargazer.stargazer import Stargazer\n",
    "from IPython.display import display, HTML, Markdown\n",
    "\n",
    "import geopandas as gpd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import requests\n",
    "from io import BytesIO\n",
    "\n",
    "import warnings\n",
    "\n",
    "import itertools\n",
    "scaler = MinMaxScaler()\n",
    "\n",
    "# Suppress RuntimeWarnings related to \"All-NaN slice encountered\"\n",
    "warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
    "\n",
    "# -----------------------------------------------\n",
    "# Input\n",
    "# -----------------------------------------------\n",
    "\n",
    "\n",
    "# Set the selected ECI variable (choose one: 'eci_trade', 'eci_tech', 'eci_research')\n",
    "selected_eci = 'eci_trade' \n",
    "\n",
    "\n",
    "# Years of interest\n",
    "initial_year = 1999\n",
    "\n",
    "# fixed effects\n",
    "fixed_effects = 'year'\n",
    "\n",
    "# Define the time period T\n",
    "T = 10\n",
    "\n",
    "# Generate the years with a step of T\n",
    "years = list(range(initial_year, 2019, T))  # Include 2021 by setting stop to 2022\n",
    "\n",
    "\n",
    "# -----------------------------------------------\n",
    "# Download GDP per capita data\n",
    "# -----------------------------------------------\n",
    "\n",
    "#eci_trade_raw = pd.read_csv('data/Data-ECI-Trade.csv')\n",
    "eci_trade_raw = pd.read_csv('data/Data-ECI-Trade.csv')\n",
    "\n",
    "    \n",
    "# Download GDP data\n",
    "gdp_url = \"https://api.worldbank.org/v2/en/indicator/NY.GDP.PCAP.PP.KD?downloadformat=excel\"\n",
    "response = requests.get(gdp_url)\n",
    "gdp_pc = pd.read_excel(BytesIO(response.content), sheet_name=\"Data\", header=3)\n",
    "gdp_pc.rename(columns={'Country Code': 'Country'}, inplace=True)\n",
    "years_data = [col for col in gdp_pc.columns if col.isdigit() and int(col) >= 1995]\n",
    "columns_to_keep = ['Country'] + years_data\n",
    "gdp_pc = gdp_pc[columns_to_keep]\n",
    "\n",
    "# Filter common countries\n",
    "common_countries = eci_trade_raw['Country'].unique()\n",
    "gdp_pc_norm = gdp_pc[gdp_pc['Country'].isin(eci_trade_raw['Country'].unique())].copy()\n",
    "\n",
    "# Reset the index to align rows properly\n",
    "gdp_pc_norm.reset_index(drop=True, inplace=True)\n",
    "eci_trade_raw.reset_index(drop=True, inplace=True)\n",
    "\n",
    "# Align NaN values in gdp_pc with eci_trade for the same years\n",
    "# Apply NaN alignment and z-score normalization\n",
    "for year in years_data:\n",
    "    if year in eci_trade_raw.columns:\n",
    "        gdp_pc_norm.loc[:, year] = gdp_pc_norm[year].where(~eci_trade_raw[year].isna())\n",
    "        gdp_pc_norm.loc[:, year] = zscore(np.log(gdp_pc_norm[year]), nan_policy='omit')  # Handle NaN values properly\n",
    "    \n",
    "\n",
    "pop_url = \"https://api.worldbank.org/v2/en/indicator/SP.POP.TOTL?downloadformat=excel\"\n",
    "response = requests.get(pop_url)\n",
    "pop = pd.read_excel(BytesIO(response.content), sheet_name=\"Data\", header=3)\n",
    "pop.rename(columns={'Country Code': 'Country'}, inplace=True)\n",
    "years_data = [col for col in pop.columns if col.isdigit() and int(col) >= 1995]\n",
    "columns_to_keep = ['Country'] + years_data\n",
    "pop = pop[columns_to_keep]\n",
    "\n",
    "eci_trade = eci_trade_raw.copy()\n",
    "\n",
    "def reshape_df(df, years, value_name):\n",
    "    df.columns = df.columns.astype(str)\n",
    "    df = df[['Country'] + list(map(str, years))]\n",
    "    df = df.melt(id_vars=['Country'], var_name='Year', value_name=value_name)\n",
    "    return df\n",
    "\n",
    "# Calculate growth\n",
    "growth_df = gdp_pc[['Country']].copy()\n",
    "for year in gdp_pc.columns[1:-T]:\n",
    "    future_year = str(int(year) + T)\n",
    "    if future_year in gdp_pc.columns:\n",
    "        growth_df[year] = 100*((gdp_pc[future_year] / gdp_pc[year])** (1 / T)  - 1)\n",
    "    else:\n",
    "        growth_df[year] = np.nan\n",
    "\n",
    "# Reshape ECI\n",
    "eci_trade_reshaped = reshape_df(eci_trade, years, 'eci_trade')\n",
    "gdp_pc_reshaped = reshape_df(gdp_pc, years, 'GDP_Per_Capita')\n",
    "gdp_pc_norm_reshaped = reshape_df(gdp_pc_norm,years,'gdp_pc')\n",
    "growth_reshaped = reshape_df(growth_df, years, 'growth')\n",
    "pop_reshaped = reshape_df(pop, years, 'pop')\n",
    "eci_options_dict = {\n",
    "    'eci_trade': eci_trade_reshaped\n",
    "}\n",
    "\n",
    "\n",
    "selected_eci_df = eci_options_dict[selected_eci]\n",
    "\n",
    "\n",
    "# Merge\n",
    "regression_df = selected_eci_df.merge(gdp_pc_reshaped, on=['Country', 'Year'], how='left') \\\n",
    "                               .merge(gdp_pc_norm_reshaped, on=['Country', 'Year'], how='left') \\\n",
    "                               .merge(growth_reshaped, on=['Country', 'Year'], how='left') \\\n",
    "                               .merge(pop_reshaped, on=['Country', 'Year'], how='left') \n",
    "\n",
    "# Transform\n",
    "#regression_df['gdp_pc'] = np.log(regression_df['GDP_Per_Capita'])\n",
    "\n",
    "target_columns = [\"intensity_trade\", \"intensity_patents\", \"intensity_publications\", \"nat_res\"]\n",
    "\n",
    "# Apply log transformation only if the column exists\n",
    "for column_name in target_columns:\n",
    "    if column_name in regression_df.columns:\n",
    "        regression_df[column_name] = np.log(regression_df[column_name] / regression_df['pop'])  \n",
    "\n",
    "regression_df['pop'] = np.log(regression_df['pop'])\n",
    "\n",
    "\n",
    "regression_df = regression_df.dropna()\n",
    "\n",
    "# -----------------------------------------------\n",
    "# Dynamically Create Regression Formulas\n",
    "# -----------------------------------------------\n",
    "covariate_order = ['eci_trade', 'gdp_pc:eci_trade', 'gdp_pc']\n",
    "\n",
    "# Initialize base fixed effects terms\n",
    "fixed_effects_terms = []\n",
    "if 'time' in fixed_effects:\n",
    "    fixed_effects_terms.append('C(Year)')\n",
    "if 'country' in fixed_effects:\n",
    "    fixed_effects_terms.append('C(Country)')\n",
    "\n",
    "fixed_effects_str = \" + \".join(fixed_effects_terms)  # Combine fixed effects\n",
    "\n",
    "# Initialize a counter for the models\n",
    "model_counter = 0\n",
    "model_formulas = {}\n",
    "\n",
    "# Model 1: Base model\n",
    "model_formulas[model_counter] = f'growth ~ gdp_pc' + (f' + {fixed_effects_str}' if fixed_effects_str else \"\")\n",
    "model_counter += 1\n",
    "\n",
    "# Model 2: Add eci_trade\n",
    "model_formulas[model_counter] = f'growth ~ gdp_pc + eci_trade' + (f' + {fixed_effects_str}' if fixed_effects_str else \"\")\n",
    "model_counter += 1\n",
    "\n",
    "# Model 3: Interaction term\n",
    "model_formulas[model_counter] = f'growth ~ gdp_pc * eci_trade' + (f' + {fixed_effects_str}' if fixed_effects_str else \"\")\n",
    "model_counter += 1\n",
    "\n",
    "\n",
    "models = {}\n",
    "for model_num, formula in model_formulas.items():\n",
    "    model = smf.ols(formula, data=regression_df).fit(cov_type='HC1')\n",
    "    models[model_num] = model\n",
    "\n",
    "stargazer = Stargazer([models[i] for i in sorted(models.keys())])\n",
    "stargazer.covariate_order(covariate_order)\n",
    "\n",
    "# Map the raw variable names to more descriptive labels\n",
    "rename_dict = {\n",
    "    'eci_trade': 'ECI (trade)',\n",
    "    'gdp_pc': 'Log of initial GDP per capita',\n",
    "    'gdp_pc:eci_trade': 'ECI (trade) x Log of initial GDP per capita',\n",
    "    'nat_res': 'Natural resource exports per capita'\n",
    "}\n",
    "\n",
    "stargazer.rename_covariates(rename_dict)\n",
    "# Create a string representing the intervals, e.g. \"1999-2009, 2009-2019\"\n",
    "intervals_str = \", \".join([f\"{y}-{y+T}\" for y in years])\n",
    "\n",
    "# Construct your dependent variable label\n",
    "dep_var_label = f\"Annualized growth of GDP per capita (in PPP constant 2021 USD)<br>({intervals_str})\"\n",
    "\n",
    "# After generating the Stargazer HTML:\n",
    "html_output = stargazer.render_html()\n",
    "\n",
    "# Replace the default dependent variable line with your custom label\n",
    "html_output = html_output.replace(\n",
    "    \"<em>Dependent variable: growth</em>\",\n",
    "    f\"<em>Dependent variable: {dep_var_label}</em>\"\n",
    ")\n",
    "\n",
    "# Display the HTML table in Jupyter Notebook\n",
    "display(HTML(html_output))\n",
    "\n",
    "\n",
    "def predict_growth_rate(target_country, target_eci):\n",
    "    \"\"\"\n",
    "    Predict the expected growth rate for a target country and target ECI value.\n",
    "    \n",
    "    Parameters:\n",
    "    - target_country: str, the target country code\n",
    "    - target_eci: float, the target ECI value\n",
    "    \n",
    "    Returns:\n",
    "    - float, the predicted growth rate\n",
    "    \"\"\"\n",
    "    # Step 1: Get the last year of analysis\n",
    "    last_year = regression_df['Year'].astype(int).max()\n",
    "    \n",
    "    # Step 2: Retrieve GDP per capita for 2022\n",
    "    try:\n",
    "        gdp_pc_2022 = gdp_pc_norm.loc[gdp_pc_norm['Country'] == target_country, '2022'].values[0]\n",
    "    except IndexError:\n",
    "        raise ValueError(f\"GDP per capita data for {target_country} in 2022 is not available.\")\n",
    "    \n",
    "    # Step 3: Prepare predictors\n",
    "    interaction_term = gdp_pc_2022 * target_eci\n",
    "\n",
    "    # Retrieve fixed effects from the model\n",
    "    fixed_effects_last_year = {}\n",
    "    if 'C(Year)' in fixed_effects_str:\n",
    "        fixed_effects_last_year['C(Year)[T.' + str(last_year) + ']'] = models[2].params.get('C(Year)[T.' + str(last_year) + ']', 0)\n",
    "    if 'C(Country)' in fixed_effects_str:\n",
    "        fixed_effects_last_year['C(Country)[T.' + target_country + ']'] = models[2].params.get('C(Country)[T.' + target_country + ']', 0)\n",
    "    \n",
    "    # Step 4: Calculate expected growth rate\n",
    "    expected_growth_rate = models[2].params['Intercept']\n",
    "    expected_growth_rate += models[2].params['eci_trade'] * target_eci\n",
    "    expected_growth_rate += models[2].params['gdp_pc'] * gdp_pc_2022\n",
    "    expected_growth_rate += models[2].params['gdp_pc:eci_trade'] * interaction_term\n",
    "\n",
    "    # Add fixed effects\n",
    "    for effect, value in fixed_effects_last_year.items():\n",
    "        expected_growth_rate += value\n",
    "    \n",
    "    return expected_growth_rate\n",
    "\n",
    "\n",
    "def predict_eci_from_growth_rate(target_country, target_growth_rate):\n",
    "    \"\"\"\n",
    "    Predict the required ECI value to achieve a target growth rate for a given country.\n",
    "    \n",
    "    Parameters:\n",
    "    - target_country: str, the target country code.\n",
    "    - target_growth_rate: float, the desired growth rate.\n",
    "    \n",
    "    Returns:\n",
    "    - float, the required ECI value.\n",
    "    \"\"\"\n",
    "    # Step 1: Get the last year of analysis\n",
    "    last_year = regression_df['Year'].astype(int).max()\n",
    "    \n",
    "    # Step 2: Retrieve GDP per capita for 2022\n",
    "    try:\n",
    "        gdp_pc_2022 = gdp_pc_norm.loc[gdp_pc_norm['Country'] == target_country, '2022'].values[0]\n",
    "    except IndexError:\n",
    "        raise ValueError(f\"GDP per capita data for {target_country} in 2022 is not available.\")\n",
    "    \n",
    "    # Step 3: Calculate fixed effects contributions\n",
    "    fixed_effects_sum = 0\n",
    "    if 'C(Year)' in fixed_effects_str:\n",
    "        fixed_effects_sum += models[2].params.get(f'C(Year)[T.{last_year}]', 0)\n",
    "    if 'C(Country)' in fixed_effects_str:\n",
    "        fixed_effects_sum += models[2].params.get(f'C(Country)[T.{target_country}]', 0)\n",
    "    \n",
    "    # Step 4: Compute constant term A and coefficient term B\n",
    "    A = models[2].params['Intercept'] + models[2].params['gdp_pc'] * gdp_pc_2022 + fixed_effects_sum\n",
    "    B = models[2].params['eci_trade'] + models[2].params['gdp_pc:eci_trade'] * gdp_pc_2022\n",
    "\n",
    "    # Check that B is not zero (which would make the inversion impossible)\n",
    "    if B == 0:\n",
    "        raise ValueError(\"Cannot invert the function because the coefficient for ECI is zero.\")\n",
    "    \n",
    "    # Step 5: Solve for the target ECI\n",
    "    target_eci = (target_growth_rate - A) / B\n",
    "    return target_eci\n",
    "\n",
    "# Save the HTML output\n",
    "with open('table_s1.html', 'w') as f:\n",
    "    f.write(html_output)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1bcfe0",
   "metadata": {},
   "source": [
    "## FIGURE 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba941b7-921b-4d37-8cbc-931f6f245e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_color = [0, 0.4470, 0.7410]\n",
    "\n",
    "target_country = 'tha'\n",
    "\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(9, 8))\n",
    "\n",
    "k = np.where(probit_country_data['countries_all'] == target_country)[0]\n",
    "\n",
    "# Check if target_country is a number (int or float)\n",
    "if isinstance(target_country, (int, float)):\n",
    "    X_start = X_countries[target_country - 1, :].T \n",
    "    M_start = M_countries[target_country - 1, :].T \n",
    "else:\n",
    "    country_index = np.where(CountryRankings['Country'] == target_country)[0]\n",
    "    X_start = X_countries[country_index[0], :].T\n",
    "    M_start = M_countries[country_index[0], :].T\n",
    "    \n",
    "X_c_start = np.sum(X_start)\n",
    "X_p_start = np.sum(X_countries, axis=0).T\n",
    "W_p = X_p_start / np.sum(X_p_start)\n",
    "\n",
    "Relatedness_start = probit_country_data['Relatedness_start']\n",
    "Relatedness_start = Relatedness_start[k]\n",
    "Relatedness_start = Relatedness_start.to_numpy()\n",
    "\n",
    "Relative_relatedness_start = probit_country_data['Relative_relatedness_start']\n",
    "Relative_relatedness_start = Relative_relatedness_start[k]\n",
    "Relative_relatedness_start = Relative_relatedness_start.to_numpy()\n",
    "\n",
    "predicted_prob = probit_country_data['predicted_prob']\n",
    "predicted_prob = predicted_prob[k]\n",
    "predicted_prob = predicted_prob.to_numpy()\n",
    "\n",
    "RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "ProductRankings['X_start'] = X_start\n",
    "ProductRankings['Relatedness_start'] = Relatedness_start\n",
    "ProductRankings['Relative_relatedness_start'] = Relative_relatedness_start  \n",
    "ProductRankings['predicted_prob'] = predicted_prob      \n",
    "ProductRankings['X_p_start'] = X_p_start\n",
    "ProductRankings['W_p'] = W_p\n",
    "ProductRankings['RCA_start'] = RCA_start\n",
    "ProductRankings['M_start'] = M_start\n",
    "\n",
    "ProductRankings['relative_relatedness'] = np.where(ProductRankings['RCA_start'] >= 1, np.nan, ProductRankings['Relatedness_start'])\n",
    "\n",
    "# Calculate the z-score for 'Relatedness_start', ignoring NaN values\n",
    "ProductRankings['relative_relatedness'] = zscore(ProductRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "ProductRankings['world_exp'] = X_p_start\n",
    "ProductRankings['percentile'] = pd.qcut(ProductRankings['world_exp'], 5, labels=False) + 1\n",
    "\n",
    "indices_to_exclude = []\n",
    "\n",
    "\n",
    "# Initialize empty lists to store results\n",
    "max_ECI_targets = []\n",
    "growth_targets = []\n",
    "Y_c_opt_values = []\n",
    "Y_c_baseline_values = []\n",
    "Relative_relatedness_opt_values = []\n",
    "Relative_relatedness_baseline_values = []\n",
    "RCA_opt_values = []\n",
    "RCA_baseline_values = []\n",
    "\n",
    "added_volume = []\n",
    "\n",
    "\n",
    "pci = ProductRankings['PCI'].values\n",
    "# Set the initial ECI and increment\n",
    "ECI_initial = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI_not_normalized'].values[0]\n",
    "ECI_initial_norm = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI'].values[0]\n",
    "\n",
    "# Find the bin where the ECI falls\n",
    "matching_row = eci_pci_whisker_df[\n",
    "    (eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "    (eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "]\n",
    "\n",
    "# Extract the threshold\n",
    "if not matching_row.empty:\n",
    "    tresh = matching_row['PCI_UpperWhisker'].values[0]\n",
    "else:\n",
    "    tresh = np.nan  # fallback if no bin matches\n",
    "\n",
    "\n",
    "initial_growth_potential = predict_growth_rate(str.upper(target_country),(ECI_initial-mean_for_countries)/sd_for_countries)\n",
    "print(f\"Expected growth of '{target_country.upper()}' without optimization is: {initial_growth_potential:.4f}\")\n",
    "# Example 1\n",
    "max_ECI_target =  predict_eci_from_growth_rate(str.upper(target_country), 3.5)\n",
    "growth_potential_1 = predict_growth_rate(str.upper(target_country),max_ECI_target)\n",
    "print(f\"Expected growth of '{target_country.upper()}' when target ECI is {max_ECI_target:.4f} is : {growth_potential_1:.4f}\")\n",
    "\n",
    "df_eci = eciopt.eci_optimization(target_country, mean_for_countries + sd_for_countries*max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country, tresh)\n",
    "    \n",
    "Relatedness_start_baseline = (Relatedness_start - Relatedness_start.min()) / (Relatedness_start.max() - Relatedness_start.min())\n",
    "pci = ProductRankings['PCI'].values\n",
    "pci_baseline = (pci - pci.min()) / (pci.max() - pci.min())\n",
    "baseline_vals = Relatedness_start_baseline * pci_baseline\n",
    "\n",
    "df_baseline = eciopt.find_products_criteria(target_country, mean_for_countries + sd_for_countries*max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, baseline_vals)\n",
    "\n",
    "ProductRankings['ECI_opt_1'] = df_eci['Added_vol'].copy()\n",
    "ProductRankings['ECI_baseline_1'] = df_baseline['Added_vol'].copy()\n",
    "\n",
    "# Perform the calculations for Ycp_entry, Ycp_exit, and Ycp here\n",
    "\n",
    "RCA_start_entry = RCA_start[RCA_start < 1]\n",
    "Relatedness_start_entry = Relatedness_start[RCA_start < 1]\n",
    "Relative_relatedness_start_entry = Relative_relatedness_start[RCA_start < 1]\n",
    "\n",
    "RCA_start_exit = RCA_start[RCA_start >= 1]\n",
    "Relatedness_start_exit = Relatedness_start[RCA_start >= 1]\n",
    "Relative_relatedness_start_exit = Relative_relatedness_start[RCA_start >= 1]\n",
    "\n",
    "Ycp_entry = np.exp((np.log(2) - (beta_country_entry[0] + beta_country_entry[2] * np.log(1 + RCA_start_entry) + beta_country_entry[3] * Relatedness_start_entry + beta_country_entry[4] * Relative_relatedness_start_entry)) / beta_country_entry[1]) - RCA_start_entry - 1\n",
    "Ycp_exit = np.exp((np.log(2) - (beta_country_exit[0] + beta_country_exit[2] * np.log(1 + RCA_start_exit) + beta_country_exit[3] * Relatedness_start_exit + beta_country_exit[4] * Relative_relatedness_start_exit)) / beta_country_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "Ycp = np.full(RCA_start.shape, np.nan)\n",
    "Ycp[RCA_start < 1] = Ycp_entry\n",
    "Ycp[RCA_start >= 1] = Ycp_exit\n",
    "\n",
    "ProductRankings['Ycp'] = Ycp\n",
    "\n",
    "\n",
    "# Step 3: Define the columns and titles for the first row\n",
    "\n",
    "cleared_data = ProductRankings[(ProductRankings['RCA_start'] < 1) & (ProductRankings['M_start'] < 1)].copy()\n",
    "\n",
    "# Define parameters for scatter plots\n",
    "colors = {\n",
    "    'ECI_opt': [0, 0.4470, 0.7410],\n",
    "    'ECI_baseline': [0.8500, 0.3250, 0.0980]\n",
    "}\n",
    "markers = {\n",
    "    'ECI_opt': 'o',\n",
    "    'ECI_baseline': 's'\n",
    "}\n",
    "enumerations = ['a']\n",
    "legend_labels = ['ECI (baseline)', 'ECI (opt)']\n",
    "alphas = [0.5, 0.5 , 1]\n",
    "\n",
    "cleared_data = ProductRankings[(ProductRankings['RCA_start'] < 1) & (ProductRankings['M_start'] < 1)].copy()\n",
    "\n",
    "\n",
    "# Loop through all subplots in the first row\n",
    "normal_data = cleared_data[(cleared_data['ECI_opt_1'] == 0) & (cleared_data['ECI_baseline_1'] == 0)].copy()\n",
    "normal_data['highlight_color'] = [[200/255, 200/255, 200/255]] * len(normal_data)\n",
    "\n",
    "ax = axes[0, 0]\n",
    "\n",
    "ax.scatter(\n",
    "    normal_data['Ycp'],\n",
    "    normal_data['PCI'],\n",
    "    c=normal_data['highlight_color'],  # Use RGB colors\n",
    "    alpha=0.2\n",
    ")\n",
    "\n",
    "    \n",
    "for j, column_prefix in enumerate(['ECI_baseline', 'ECI_opt']):\n",
    "    # Dynamically filter data\n",
    "    highlight_data = cleared_data[cleared_data[f'{column_prefix}_1'] > 0].copy()\n",
    "    highlight_data['highlight_color'] = [colors[column_prefix]] * len(highlight_data)\n",
    "\n",
    "    # Scatter plot\n",
    "    ax.scatter(\n",
    "        highlight_data['Ycp'],\n",
    "        highlight_data['PCI'],\n",
    "        s=highlight_data['percentile'] * 20,\n",
    "        c=highlight_data['highlight_color'],\n",
    "        edgecolor='black',\n",
    "        linewidth=0.5,\n",
    "        alpha=alphas[j],\n",
    "        label=legend_labels[j],\n",
    "        marker=markers[column_prefix]\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "# Add a dashed horizontal line at ECI_initial\n",
    "ax.axhline(y=ECI_initial, color='black', linestyle='--', linewidth=2, alpha=0.4)\n",
    "ax.text(0.24, ECI_initial - 0.05, 'Estimated average PCI',\n",
    "        transform=ax.get_yaxis_transform(), fontsize=10, va='top', ha='center')\n",
    "\n",
    "titles = f\"Thailand\\nTarget growth: {growth_potential_1:.2f} -> Expected ECI = {max_ECI_target:.2f}\"\n",
    "ax.set_title(titles, fontsize=12, ha='center')\n",
    "ax.set_xlabel('Estimated Effort', fontsize=12)\n",
    "ax.set_ylabel('Estimated PCI in 2032', fontsize=12)\n",
    "ax.text(-0.05, 1.15, enumerations[0], transform=ax.transAxes, fontsize=20, fontweight='normal', va='top', ha='left')  # Left-aligned letter\n",
    "\n",
    "\n",
    "# Add a horizontal legend inside the subplot\n",
    "ax.legend(\n",
    "    loc='lower left',  # Position inside the plot on the left\n",
    "    bbox_to_anchor=(0.02, 0.02),  # Moves it slightly inside from the bottom-left\n",
    "    frameon=True,  # Keep a frame around the legend for clarity\n",
    "    fontsize=10\n",
    ")\n",
    "        \n",
    "## 2nd ROW!!!        \n",
    "\n",
    "max_ECI_target = ECI_initial + 0.05 * sd_for_countries\n",
    "increment = 0.05 * sd_for_countries\n",
    "max_ECI_target_final =  predict_eci_from_growth_rate(str.upper(target_country), 3.5)\n",
    "\n",
    "\n",
    "\n",
    "# Loop over max_ECI_target from ECI_initial + 0.01 to 1.8\n",
    "# Define the stopping threshold\n",
    "threshold = max_ECI_target_final * sd_for_countries + mean_for_countries\n",
    "final_run_done = False  # Flag to ensure the last exact run happens\n",
    "\n",
    "while (max_ECI_target - mean_for_countries) / sd_for_countries <= max_ECI_target_final:\n",
    "    # Perform your calculations for df_eci_country, df_eci_rel, df_eci_cplex here\n",
    "    df_eci_country = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country,tresh)\n",
    "    growth_potential = predict_growth_rate(str.upper(target_country), \n",
    "                                           (max_ECI_target - mean_for_countries) / sd_for_countries)\n",
    "\n",
    "    df_baseline = eciopt.find_products_criteria(target_country, max_ECI_target, CountryRankings, ProductRankings, \n",
    "                                                indices_to_exclude, beta_country_entry, beta_country_exit, baseline_vals)\n",
    "\n",
    "    added_volume.append(df_eci_country['Added_vol'].values)\n",
    "\n",
    "    # Calculate Y_c_opt, Y_c_rel, and Y_c_cplex\n",
    "    Y_c_opt = np.sum(Ycp * (df_eci_country['Added_vol'] > 0))\n",
    "    Y_c_baseline = np.sum(Ycp * (df_baseline['Added_vol'] > 0))\n",
    "\n",
    "    df_eci_country['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    Rel_relatedness_opt = df_eci_country[df_eci_country['Added_vol'] > 0]['Relative_relatedness_start'].mean()\n",
    "\n",
    "    df_baseline['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    Rel_relatedness_baseline = df_baseline[df_baseline['Added_vol'] > 0]['Relative_relatedness_start'].mean()\n",
    "\n",
    "    df_eci_country['RCA_start'] = RCA_start\n",
    "    RCA_opt = df_eci_country[df_eci_country['Added_vol'] > 0]['RCA_start'].mean()\n",
    "\n",
    "    df_baseline['RCA_start'] = RCA_start\n",
    "    RCA_baseline = df_baseline[df_baseline['Added_vol'] > 0]['RCA_start'].mean()\n",
    "\n",
    "    # Append results to the lists\n",
    "    max_ECI_targets.append(max_ECI_target)\n",
    "    growth_targets.append(growth_potential)\n",
    "    Y_c_opt_values.append(Y_c_opt)\n",
    "    Y_c_baseline_values.append(Y_c_baseline)\n",
    "    Relative_relatedness_opt_values.append(Rel_relatedness_opt)\n",
    "    Relative_relatedness_baseline_values.append(Rel_relatedness_baseline)\n",
    "    RCA_opt_values.append(RCA_opt)\n",
    "    RCA_baseline_values.append(RCA_baseline)\n",
    "\n",
    "    print(max_ECI_target)\n",
    "\n",
    "    # Increment max_ECI_target\n",
    "    next_ECI_target = max_ECI_target + increment\n",
    "\n",
    "    # Ensure a final iteration at the exact threshold\n",
    "    if next_ECI_target > threshold and not final_run_done:\n",
    "        max_ECI_target = threshold  # Force one last exact run\n",
    "        final_run_done = True  # Mark that we have done the last run\n",
    "    elif final_run_done:\n",
    "        break  # After the final run, exit the loop\n",
    "    else:\n",
    "        max_ECI_target = next_ECI_target  # Normal increment\n",
    "\n",
    "from matplotlib.colors import Normalize\n",
    "\n",
    "\n",
    "ProductRankings_final = ProductRankings[['Product','Ycp','RCA_start', 'M_start', 'PCI','Relative_relatedness_start', 'Relatedness_start']]\n",
    "\n",
    "# Round max_ECI_targets to two decimals\n",
    "max_ECI_targets_rounded = [f\"ECI_target_{round(val, 2)}\" for val in (max_ECI_targets-mean_for_countries)/sd_for_countries]\n",
    "\n",
    "# Convert the added_volume array to a DataFrame with column names as the rounded ECI target values\n",
    "added_volume_df = pd.DataFrame(added_volume, columns=ProductRankings_final.index).T\n",
    "added_volume_df.columns = max_ECI_targets_rounded\n",
    "\n",
    "# Add the columns to ProductRankings_final\n",
    "ProductRankings_final = pd.concat([ProductRankings_final, added_volume_df], axis=1)\n",
    "\n",
    "# Define columns starting with \"ECI_target_\" in ProductRankings_final\n",
    "eci_target_columns = [col for col in ProductRankings_final.columns if col.startswith(\"ECI_target_\")]\n",
    "\n",
    "# Extract the numerical part from each \"ECI_target_\" column name\n",
    "eci_target_values = [float(col.replace(\"ECI_target_\", \"\")) for col in eci_target_columns]\n",
    "\n",
    "# Create a dictionary to map each column name to its numerical target value\n",
    "eci_target_map = dict(zip(eci_target_columns, eci_target_values))\n",
    "\n",
    "# Define a function to get the first non-zero ECI target value or NaN if all are zero\n",
    "def get_first_non_zero_value(row):\n",
    "    non_zero_columns = row[row > 0].index\n",
    "    return eci_target_map[non_zero_columns[0]] if len(non_zero_columns) > 0 else np.nan\n",
    "\n",
    "# Apply the function to each row of the selected columns and add the result as a new column\n",
    "ProductRankings_final['First_non_zero_ECI_target'] = ProductRankings_final[eci_target_columns].apply(get_first_non_zero_value, axis=1)\n",
    "\n",
    "\n",
    "# Display the updated DataFrame\n",
    "display(ProductRankings_final[ProductRankings_final['First_non_zero_ECI_target'].notna()])\n",
    "\n",
    "\n",
    "filtered_data = ProductRankings_final[(ProductRankings_final['RCA_start'] < 1) & (ProductRankings_final['M_start'] < 1)]\n",
    "\n",
    "\n",
    "# Prepare the DataFrame for display\n",
    "# Select and rename the desired columns\n",
    "table_data = ProductRankings_final[['Product', 'RCA_start', 'Relatedness_start','Relative_relatedness_start', 'PCI', 'First_non_zero_ECI_target']].copy()\n",
    "table_data = table_data.rename(columns={\n",
    "    'RCA_start': r'$\\mathit{R_{cp}}$',\n",
    "    'Relatedness_start': r'$\\omega_{cp}$',\n",
    "    'Relative_relatedness_start': r'$\\tilde{\\omega}_{cp}$',\n",
    "    'PCI': 'PCI',\n",
    "    'First_non_zero_ECI_target': 'Target ECI'\n",
    "})\n",
    "\n",
    "\n",
    "# Format 'R_{cp}', 'Relatedness', and 'PCI' columns to 3 decimal places as strings\n",
    "# Format 'R_{cp}', 'Relatedness', and 'PCI' columns to 3 decimal places as strings\n",
    "table_data[r'$\\mathit{R_{cp}}$'] = table_data[r'$\\mathit{R_{cp}}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "table_data[r'$\\omega_{cp}$'] = table_data[r'$\\omega_{cp}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "\n",
    "table_data[r'$\\tilde{\\omega}_{cp}$'] = table_data[r'$\\tilde{\\omega}_{cp}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "table_data['PCI'] = table_data['PCI'].apply(lambda x: f\"{x:.3f}\")\n",
    "\n",
    "table_data['Target Growth'] = table_data.apply(\n",
    "    lambda row: predict_growth_rate(target_country.upper(), row['Target ECI']), \n",
    "    axis=1\n",
    ")\n",
    "\n",
    "table_data['Target Growth'] = table_data['Target Growth'].apply(lambda x: f\"{x:.3f}\")\n",
    "\n",
    "\n",
    "# Filter to show only rows where 'Target ECI' is not NaN, and sort by 'Target ECI'\n",
    "table_data = table_data[table_data['Target ECI'].notna()].sort_values(by='Target ECI')\n",
    "\n",
    "# Select top 5, middle 5, and bottom 5 rows for display\n",
    "top_5 = table_data.head(5)\n",
    "middle_5 = table_data.iloc[len(table_data) // 2 - 2 : len(table_data) // 2 + 3]\n",
    "bottom_5 = table_data.tail(5)\n",
    "\n",
    "# Concatenate the sections with a row of '...' in between\n",
    "ellipsis_row = pd.DataFrame([['...'] * len(table_data.columns)], columns=table_data.columns)\n",
    "table_display = pd.concat([top_5, ellipsis_row, middle_5, ellipsis_row, bottom_5])\n",
    "\n",
    "\n",
    "# Create a 1x3 subplot figure\n",
    "\n",
    "\n",
    "\n",
    "#fig.text(0.5, 0.98, 'Thailand', ha='center', va='top', fontsize=20, fontweight='normal')\n",
    "#fig.text(0.5, 0.47, 'Mexico', ha='center', va='top', fontsize=20, fontweight='normal')\n",
    "\n",
    "# Third subplot - Table\n",
    "axes[1, 0].axis('off')  # Turn off the axis for the table\n",
    "axes[1, 0].set_title(\"Thailand\", fontsize=12, pad = 20)\n",
    "\n",
    "table = axes[1, 0].table(cellText=table_display.values, colLabels=[r'$\\mathbf{Product}$', \n",
    "                                                               r'$\\mathbf{R_{cp}}$', \n",
    "                                                               r'$\\mathbf{\\omega_{cp}}$',\n",
    "                                                               r'$\\mathbf{\\tilde{\\omega}_{cp}}$', \n",
    "                                                               r'$\\mathbf{PCI}$', \n",
    "                                                               r'$\\mathbf{ECI}$',\n",
    "                                                               r'$\\mathbf{Growth (\\%)}$'],  \n",
    "                      cellLoc='center', loc='center')\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(10)\n",
    "\n",
    "table.auto_set_column_width(col=list(range(len(table_display.columns))))  # Adjust column width\n",
    "axes[1, 0].text(-0.05, 1.15, 'c', transform=axes[1, 0].transAxes, fontsize=20, \n",
    "             verticalalignment='top', horizontalalignment='right')\n",
    "\n",
    "\n",
    "target_country = 'mex'\n",
    "\n",
    "\n",
    "k = np.where(probit_country_data['countries_all'] == target_country)[0]\n",
    "\n",
    "# Check if target_country is a number (int or float)\n",
    "if isinstance(target_country, (int, float)):\n",
    "    X_start = X_countries[target_country - 1, :].T \n",
    "    M_start = M_countries[target_country - 1, :].T \n",
    "else:\n",
    "    country_index = np.where(CountryRankings['Country'] == target_country)[0]\n",
    "    X_start = X_countries[country_index[0], :].T\n",
    "    M_start = M_countries[country_index[0], :].T\n",
    "    \n",
    "X_c_start = np.sum(X_start)\n",
    "X_p_start = np.sum(X_countries, axis=0).T\n",
    "W_p = X_p_start / np.sum(X_p_start)\n",
    "\n",
    "Relatedness_start = probit_country_data['Relatedness_start']\n",
    "Relatedness_start = Relatedness_start[k]\n",
    "Relatedness_start = Relatedness_start.to_numpy()\n",
    "\n",
    "Relative_relatedness_start = probit_country_data['Relative_relatedness_start']\n",
    "Relative_relatedness_start = Relative_relatedness_start[k]\n",
    "Relative_relatedness_start = Relative_relatedness_start.to_numpy()\n",
    "\n",
    "predicted_prob = probit_country_data['predicted_prob']\n",
    "predicted_prob = predicted_prob[k]\n",
    "predicted_prob = predicted_prob.to_numpy()\n",
    "\n",
    "RCA_start = (X_start/X_c_start) / W_p\n",
    "\n",
    "ProductRankings['X_start'] = X_start\n",
    "ProductRankings['Relatedness_start'] = Relatedness_start\n",
    "ProductRankings['Relative_relatedness_start'] = Relative_relatedness_start  \n",
    "ProductRankings['predicted_prob'] = predicted_prob      \n",
    "ProductRankings['X_p_start'] = X_p_start\n",
    "ProductRankings['W_p'] = W_p\n",
    "ProductRankings['RCA_start'] = RCA_start\n",
    "ProductRankings['M_start'] = M_start\n",
    "\n",
    "ProductRankings['relative_relatedness'] = np.where(ProductRankings['RCA_start'] >= 1, np.nan, ProductRankings['Relatedness_start'])\n",
    "\n",
    "# Calculate the z-score for 'Relatedness_start', ignoring NaN values\n",
    "ProductRankings['relative_relatedness'] = zscore(ProductRankings['relative_relatedness'], nan_policy='omit')\n",
    "\n",
    "ProductRankings['world_exp'] = X_p_start\n",
    "ProductRankings['percentile'] = pd.qcut(ProductRankings['world_exp'], 5, labels=False) + 1\n",
    "\n",
    "indices_to_exclude = []\n",
    "\n",
    "\n",
    "# Initialize empty lists to store results\n",
    "max_ECI_targets = []\n",
    "growth_targets = []\n",
    "Y_c_opt_values = []\n",
    "Y_c_baseline_values = []\n",
    "Relative_relatedness_opt_values = []\n",
    "Relative_relatedness_baseline_values = []\n",
    "RCA_opt_values = []\n",
    "RCA_baseline_values = []\n",
    "\n",
    "added_volume = []\n",
    "\n",
    "\n",
    "pci = ProductRankings['PCI'].values\n",
    "# Set the initial ECI and increment\n",
    "ECI_initial = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI_not_normalized'].values[0]\n",
    "\n",
    "ECI_initial_norm = CountryRankings.loc[CountryRankings['Country'] == target_country, 'ECI'].values[0]\n",
    "\n",
    "# Find the bin where the ECI falls\n",
    "matching_row = eci_pci_whisker_df[\n",
    "    (eci_pci_whisker_df['ECI_Lower'] <= ECI_initial_norm) &\n",
    "    (eci_pci_whisker_df['ECI_Upper'] > ECI_initial_norm)\n",
    "]\n",
    "\n",
    "# Extract the threshold\n",
    "if not matching_row.empty:\n",
    "    tresh = matching_row['PCI_UpperWhisker'].values[0]\n",
    "else:\n",
    "    tresh = np.nan  # fallback if no bin matches\n",
    "\n",
    "initial_growth_potential = predict_growth_rate(str.upper(target_country),(ECI_initial-mean_for_countries)/sd_for_countries)\n",
    "print(f\"Expected growth of '{target_country.upper()}' without optimization is: {initial_growth_potential:.4f}\")\n",
    "# Example 1\n",
    "max_ECI_target =  predict_eci_from_growth_rate(str.upper(target_country), 3.5)\n",
    "growth_potential_1 = predict_growth_rate(str.upper(target_country),max_ECI_target)\n",
    "print(f\"Expected growth of '{target_country.upper()}' when target ECI is {max_ECI_target:.4f} is : {growth_potential_1:.4f}\")\n",
    "\n",
    "df_eci = eciopt.eci_optimization(target_country, mean_for_countries + sd_for_countries*max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country,tresh)\n",
    "    \n",
    "Relatedness_start_baseline = (Relatedness_start - Relatedness_start.min()) / (Relatedness_start.max() - Relatedness_start.min())\n",
    "pci = ProductRankings['PCI'].values\n",
    "pci_baseline = (pci - pci.min()) / (pci.max() - pci.min())\n",
    "baseline_vals = Relatedness_start_baseline * pci_baseline\n",
    "\n",
    "df_baseline = eciopt.find_products_criteria(target_country, mean_for_countries + sd_for_countries*max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, baseline_vals)\n",
    "\n",
    "ProductRankings['ECI_opt_1'] = df_eci['Added_vol'].copy()\n",
    "ProductRankings['ECI_baseline_1'] = df_baseline['Added_vol'].copy()\n",
    "\n",
    "# Perform the calculations for Ycp_entry, Ycp_exit, and Ycp here\n",
    "\n",
    "RCA_start_entry = RCA_start[RCA_start < 1]\n",
    "Relatedness_start_entry = Relatedness_start[RCA_start < 1]\n",
    "Relative_relatedness_start_entry = Relative_relatedness_start[RCA_start < 1]\n",
    "\n",
    "RCA_start_exit = RCA_start[RCA_start >= 1]\n",
    "Relatedness_start_exit = Relatedness_start[RCA_start >= 1]\n",
    "Relative_relatedness_start_exit = Relative_relatedness_start[RCA_start >= 1]\n",
    "\n",
    "Ycp_entry = np.exp((np.log(2) - (beta_country_entry[0] + beta_country_entry[2] * np.log(1 + RCA_start_entry) + beta_country_entry[3] * Relatedness_start_entry + beta_country_entry[4] * Relative_relatedness_start_entry)) / beta_country_entry[1]) - RCA_start_entry - 1\n",
    "Ycp_exit = np.exp((np.log(2) - (beta_country_exit[0] + beta_country_exit[2] * np.log(1 + RCA_start_exit) + beta_country_exit[3] * Relatedness_start_exit + beta_country_exit[4] * Relative_relatedness_start_exit)) / beta_country_exit[1]) - RCA_start_exit - 1\n",
    "\n",
    "Ycp = np.full(RCA_start.shape, np.nan)\n",
    "Ycp[RCA_start < 1] = Ycp_entry\n",
    "Ycp[RCA_start >= 1] = Ycp_exit\n",
    "\n",
    "ProductRankings['Ycp'] = Ycp\n",
    "\n",
    "\n",
    "# Step 3: Define the columns and titles for the first row\n",
    "\n",
    "cleared_data = ProductRankings[(ProductRankings['RCA_start'] < 1) & (ProductRankings['M_start'] < 1)].copy()\n",
    "\n",
    "# Define parameters for scatter plots\n",
    "colors = {\n",
    "    'ECI_opt': [0, 0.4470, 0.7410],\n",
    "    'ECI_baseline': [0.8500, 0.3250, 0.0980]\n",
    "}\n",
    "markers = {\n",
    "    'ECI_opt': 'o',\n",
    "    'ECI_baseline': 's'\n",
    "}\n",
    "\n",
    "enumerations = ['b']\n",
    "legend_labels = ['ECI (baseline)', 'ECI (opt)']\n",
    "alphas = [0.5, 0.5 , 1]\n",
    "\n",
    "cleared_data = ProductRankings[(ProductRankings['RCA_start'] < 1) & (ProductRankings['M_start'] < 1)].copy()\n",
    "\n",
    "\n",
    "normal_data = cleared_data[(cleared_data[f'ECI_opt_1'] == 0) & (cleared_data[f'ECI_baseline_1'] == 0)].copy()\n",
    "normal_data['highlight_color'] = [[200/255, 200/255, 200/255]] * len(normal_data)\n",
    "\n",
    "ax = axes[0, 1]\n",
    "\n",
    "ax.scatter(\n",
    "    normal_data['Ycp'],\n",
    "    normal_data['PCI'],\n",
    "    c=normal_data['highlight_color'],  # Use RGB colors\n",
    "    alpha=0.2\n",
    ")\n",
    "\n",
    "    \n",
    "for j, column_prefix in enumerate(['ECI_baseline', 'ECI_opt']):\n",
    "    # Dynamically filter data\n",
    "    highlight_data = cleared_data[cleared_data[f'{column_prefix}_1'] > 0].copy()\n",
    "    highlight_data['highlight_color'] = [colors[column_prefix]] * len(highlight_data)\n",
    "\n",
    "    # Scatter plot\n",
    "    ax.scatter(\n",
    "        highlight_data['Ycp'],\n",
    "        highlight_data['PCI'],\n",
    "        s=highlight_data['percentile'] * 20,\n",
    "        c=highlight_data['highlight_color'],\n",
    "        edgecolor='black',\n",
    "        linewidth=0.5,\n",
    "        alpha=alphas[j],\n",
    "        label=legend_labels[j],\n",
    "        marker=markers[column_prefix]\n",
    "    )\n",
    "\n",
    "    # Adjust the legend to be inside the chart on the left\n",
    "\n",
    "\n",
    "# Add a dashed horizontal line at ECI_initial\n",
    "ax.axhline(y=ECI_initial, color='black', linestyle='--', linewidth=2, alpha=0.4)\n",
    "ax.text(0.24, ECI_initial - 0.05, 'Estimated average PCI',\n",
    "        transform=ax.get_yaxis_transform(), fontsize=10, va='top', ha='center')\n",
    "\n",
    "# Add titles and labels\n",
    "titles = f\"Mexico\\nTarget growth: {growth_potential:.2f} -> Expected ECI = {max_ECI_target:.2f}\"\n",
    "ax.set_title(titles, fontsize=12, ha='center')\n",
    "ax.set_xlabel('Estimated Effort', fontsize=12)\n",
    "ax.set_ylabel('Estimated PCI in 2032', fontsize=12)\n",
    "ax.text(-0.05, 1.15, enumerations[0], transform=ax.transAxes, fontsize=20, fontweight='normal', va='top', ha='left')  # Left-aligned letter\n",
    "\n",
    "\n",
    "# Add a horizontal legend inside the subplot\n",
    "ax.legend(\n",
    "    loc='lower left',  # Position inside the plot on the left\n",
    "    bbox_to_anchor=(0.02, 0.02),  # Moves it slightly inside from the bottom-left\n",
    "    frameon=True,  # Keep a frame around the legend for clarity\n",
    "    fontsize=10\n",
    ")\n",
    "        \n",
    "## 2nd ROW!!!        \n",
    "\n",
    "max_ECI_target = ECI_initial + 0.05 * sd_for_countries\n",
    "increment = 0.05 * sd_for_countries\n",
    "max_ECI_target_final =  predict_eci_from_growth_rate(str.upper(target_country), 3.5)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Loop over max_ECI_target from ECI_initial + 0.01 to 1.8\n",
    "# Define the stopping threshold\n",
    "threshold = max_ECI_target_final * sd_for_countries + mean_for_countries\n",
    "final_run_done = False  # Flag to ensure the last exact run happens\n",
    "\n",
    "while (max_ECI_target - mean_for_countries) / sd_for_countries <= max_ECI_target_final:\n",
    "    # Perform your calculations for df_eci_country, df_eci_rel, df_eci_cplex here\n",
    "    df_eci_country = eciopt.eci_optimization(target_country, max_ECI_target, CountryRankings, ProductRankings, indices_to_exclude, beta_country_entry, beta_country_exit, PHIpp_country, tresh)\n",
    "\n",
    "    growth_potential = predict_growth_rate(str.upper(target_country), \n",
    "                                           (max_ECI_target - mean_for_countries) / sd_for_countries)\n",
    "\n",
    "    df_baseline = eciopt.find_products_criteria(target_country, max_ECI_target, CountryRankings, ProductRankings, \n",
    "                                                indices_to_exclude, beta_country_entry, beta_country_exit, baseline_vals)\n",
    "\n",
    "    added_volume.append(df_eci_country['Added_vol'].values)\n",
    "\n",
    "    # Calculate Y_c_opt, Y_c_rel, and Y_c_cplex\n",
    "    Y_c_opt = np.sum(Ycp * (df_eci_country['Added_vol'] > 0))\n",
    "    Y_c_baseline = np.sum(Ycp * (df_baseline['Added_vol'] > 0))\n",
    "\n",
    "    df_eci_country['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    Rel_relatedness_opt = df_eci_country[df_eci_country['Added_vol'] > 0]['Relative_relatedness_start'].mean()\n",
    "\n",
    "    df_baseline['Relative_relatedness_start'] = Relative_relatedness_start\n",
    "    Rel_relatedness_baseline = df_baseline[df_baseline['Added_vol'] > 0]['Relative_relatedness_start'].mean()\n",
    "\n",
    "    df_eci_country['RCA_start'] = RCA_start\n",
    "    RCA_opt = df_eci_country[df_eci_country['Added_vol'] > 0]['RCA_start'].mean()\n",
    "\n",
    "    df_baseline['RCA_start'] = RCA_start\n",
    "    RCA_baseline = df_baseline[df_baseline['Added_vol'] > 0]['RCA_start'].mean()\n",
    "\n",
    "    # Append results to the lists\n",
    "    max_ECI_targets.append(max_ECI_target)\n",
    "    growth_targets.append(growth_potential)\n",
    "    Y_c_opt_values.append(Y_c_opt)\n",
    "    Y_c_baseline_values.append(Y_c_baseline)\n",
    "    Relative_relatedness_opt_values.append(Rel_relatedness_opt)\n",
    "    Relative_relatedness_baseline_values.append(Rel_relatedness_baseline)\n",
    "    RCA_opt_values.append(RCA_opt)\n",
    "    RCA_baseline_values.append(RCA_baseline)\n",
    "\n",
    "    print(max_ECI_target)\n",
    "\n",
    "    # Increment max_ECI_target\n",
    "    next_ECI_target = max_ECI_target + increment\n",
    "\n",
    "    # Ensure a final iteration at the exact threshold\n",
    "    if next_ECI_target > threshold and not final_run_done:\n",
    "        max_ECI_target = threshold  # Force one last exact run\n",
    "        final_run_done = True  # Mark that we have done the last run\n",
    "    elif final_run_done:\n",
    "        break  # After the final run, exit the loop\n",
    "    else:\n",
    "        max_ECI_target = next_ECI_target  # Normal increment\n",
    "\n",
    "from matplotlib.colors import Normalize\n",
    "\n",
    "\n",
    "ProductRankings_final = ProductRankings[['Product','Ycp','RCA_start', 'M_start', 'PCI','Relative_relatedness_start', 'Relatedness_start']]\n",
    "\n",
    "# Round max_ECI_targets to two decimals\n",
    "max_ECI_targets_rounded = [f\"ECI_target_{round(val, 2)}\" for val in (max_ECI_targets-mean_for_countries)/sd_for_countries]\n",
    "\n",
    "# Convert the added_volume array to a DataFrame with column names as the rounded ECI target values\n",
    "added_volume_df = pd.DataFrame(added_volume, columns=ProductRankings_final.index).T\n",
    "added_volume_df.columns = max_ECI_targets_rounded\n",
    "\n",
    "# Add the columns to ProductRankings_final\n",
    "ProductRankings_final = pd.concat([ProductRankings_final, added_volume_df], axis=1)\n",
    "\n",
    "# Define columns starting with \"ECI_target_\" in ProductRankings_final\n",
    "eci_target_columns = [col for col in ProductRankings_final.columns if col.startswith(\"ECI_target_\")]\n",
    "\n",
    "# Extract the numerical part from each \"ECI_target_\" column name\n",
    "eci_target_values = [float(col.replace(\"ECI_target_\", \"\")) for col in eci_target_columns]\n",
    "\n",
    "# Create a dictionary to map each column name to its numerical target value\n",
    "eci_target_map = dict(zip(eci_target_columns, eci_target_values))\n",
    "\n",
    "# Define a function to get the first non-zero ECI target value or NaN if all are zero\n",
    "def get_first_non_zero_value(row):\n",
    "    non_zero_columns = row[row > 0].index\n",
    "    return eci_target_map[non_zero_columns[0]] if len(non_zero_columns) > 0 else np.nan\n",
    "\n",
    "# Apply the function to each row of the selected columns and add the result as a new column\n",
    "ProductRankings_final['First_non_zero_ECI_target'] = ProductRankings_final[eci_target_columns].apply(get_first_non_zero_value, axis=1)\n",
    "\n",
    "\n",
    "# Display the updated DataFrame\n",
    "display(ProductRankings_final[ProductRankings_final['First_non_zero_ECI_target'].notna()])\n",
    "\n",
    "\n",
    "filtered_data = ProductRankings_final[(ProductRankings_final['RCA_start'] < 1) & (ProductRankings_final['M_start'] < 1)]\n",
    "\n",
    "\n",
    "\n",
    "# Define colors for each line in the first subplot\n",
    "cplex_color = [0.4940, 0.1840, 0.5560]\n",
    "baseline_color = [0.9290, 0.6940, 0.1250]\n",
    "\n",
    "# Normalize the `First_non_zero_ECI_target` values to the range 0.2 to 0.8\n",
    "norm = Normalize(vmin=filtered_data['First_non_zero_ECI_target'].min(), \n",
    "                 vmax=filtered_data['First_non_zero_ECI_target'].max())\n",
    "alphas = 0.2 + (norm(filtered_data['First_non_zero_ECI_target']) * 0.6)  # Scale to [0.2, 0.8]\n",
    "\n",
    "\n",
    "\n",
    "# Prepare the DataFrame for display\n",
    "# Select and rename the desired columns\n",
    "table_data = ProductRankings_final[['Product', 'RCA_start', 'Relatedness_start','Relative_relatedness_start', 'PCI', 'First_non_zero_ECI_target']].copy()\n",
    "table_data = table_data.rename(columns={\n",
    "    'RCA_start': r'$\\mathit{R_{cp}}$',\n",
    "    'Relatedness_start': r'$\\omega_{cp}$',\n",
    "    'Relative_relatedness_start': r'$\\tilde{\\omega}_{cp}$',\n",
    "    'PCI': 'PCI',\n",
    "    'First_non_zero_ECI_target': 'Target ECI'\n",
    "})\n",
    "\n",
    "# Format 'R_{cp}', 'Relatedness', and 'PCI' columns to 3 decimal places as strings\n",
    "# Format 'R_{cp}', 'Relatedness', and 'PCI' columns to 3 decimal places as strings\n",
    "table_data[r'$\\mathit{R_{cp}}$'] = table_data[r'$\\mathit{R_{cp}}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "table_data[r'$\\omega_{cp}$'] = table_data[r'$\\omega_{cp}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "table_data[r'$\\tilde{\\omega}_{cp}$'] = table_data[r'$\\tilde{\\omega}_{cp}$'].apply(lambda x: f\"{x:.3f}\")\n",
    "table_data['PCI'] = table_data['PCI'].apply(lambda x: f\"{x:.3f}\")\n",
    "\n",
    "table_data['Target Growth'] = table_data.apply(\n",
    "    lambda row: predict_growth_rate(target_country.upper(), row['Target ECI']), \n",
    "    axis=1\n",
    ")\n",
    "\n",
    "table_data['Target Growth'] = table_data['Target Growth'].apply(lambda x: f\"{x:.3f}\")\n",
    "\n",
    "# Filter to show only rows where 'Target ECI' is not NaN, and sort by 'Target ECI'\n",
    "table_data = table_data[table_data['Target ECI'].notna()].sort_values(by='Target ECI')\n",
    "\n",
    "# Select top 5, middle 5, and bottom 5 rows for display\n",
    "top_5 = table_data.head(5)\n",
    "middle_5 = table_data.iloc[len(table_data) // 2 - 2 : len(table_data) // 2 + 3]\n",
    "bottom_5 = table_data.tail(5)\n",
    "\n",
    "# Concatenate the sections with a row of '...' in between\n",
    "ellipsis_row = pd.DataFrame([['...'] * len(table_data.columns)], columns=table_data.columns)\n",
    "table_display = pd.concat([top_5, ellipsis_row, middle_5, ellipsis_row, bottom_5])\n",
    "\n",
    "\n",
    "# Create a 1x3 subplot figure\n",
    "\n",
    "# Third subplot - Table\n",
    "axes[1, 1].axis('off')  # Turn off the axis for the table\n",
    "axes[1, 1].set_title(\"Mexico\", fontsize=12, pad = 20)\n",
    "\n",
    "table = axes[1, 1].table(cellText=table_display.values, colLabels=[r'$\\mathbf{Product}$', \n",
    "                                                               r'$\\mathbf{R_{cp}}$', \n",
    "                                                               r'$\\mathbf{\\omega_{cp}}$',\n",
    "                                                               r'$\\mathbf{\\tilde{\\omega}_{cp}}$', \n",
    "                                                               r'$\\mathbf{PCI}$', \n",
    "                                                               r'$\\mathbf{ECI}$',\n",
    "                                                               r'$\\mathbf{Growth (\\%)}$'],  \n",
    "                      cellLoc='center', loc='center')\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(10)\n",
    "\n",
    "table.auto_set_column_width(col=list(range(len(table_display.columns))))  # Adjust column width\n",
    "axes[1, 1].text(-0.05, 1.15, 'd', transform=axes[1, 1].transAxes, fontsize=20, \n",
    "             verticalalignment='top', horizontalalignment='right')\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "#fig.subplots_adjust(top=0.90, hspace=0.42)  # Adjust the vertical spacing between rows\n",
    "\n",
    "plt.savefig('figure5.png', dpi=300)\n",
    "\n",
    "plt.show()  \n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
